diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 0c4aa63669..3ae00a7ee8 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,3 +1,25 @@ +3.29.9 +====== +March 18, 2026 + +Features +-------- +* Add Private Link support via client routes handler +* Add optional query_params parameter to QueryMessage + +Bug Fixes +--------- +* Fix segmentation fault in libev prepare_callback during shutdown +* Add null checks to io_callback and timer_callback in libev wrapper +* Fix RecursionError in execute_concurrent on synchronous errbacks +* Fix floating-point precision loss for timestamps far from epoch + +Others +------ +* Cache parsed tablet routing type in ResponseFuture +* Remove deprecated setup_requires in favor of PEP 517 build-system.requires +* Update dependency hatchling to v1.29.0 + 3.29.8 ====== February 09, 2026 diff --git a/benchmarks/bench_cache_apply_parameters.py b/benchmarks/bench_cache_apply_parameters.py new file mode 100644 index 0000000000..77cd4ef926 --- /dev/null +++ b/benchmarks/bench_cache_apply_parameters.py @@ -0,0 +1,78 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Micro-benchmark: apply_parameters caching. + +Measures the speedup from caching parameterized type creation +in _CassandraType.apply_parameters(). + +Run: + python benchmarks/bench_cache_apply_parameters.py +""" +import timeit +from cassandra.cqltypes import ( + MapType, SetType, ListType, TupleType, + Int32Type, UTF8Type, FloatType, DoubleType, BooleanType, + _CassandraType, +) + + +def bench_apply_parameters(): + """Benchmark apply_parameters with cache (repeated calls).""" + cache = _CassandraType._apply_parameters_cache + + # Warm up the cache + MapType.apply_parameters([UTF8Type, Int32Type]) + SetType.apply_parameters([FloatType]) + ListType.apply_parameters([DoubleType]) + TupleType.apply_parameters([Int32Type, UTF8Type, BooleanType]) + + calls = [ + (MapType, [UTF8Type, Int32Type]), + (SetType, [FloatType]), + (ListType, [DoubleType]), + (TupleType, [Int32Type, UTF8Type, BooleanType]), + ] + + def run_cached(): + for cls, subtypes in calls: + cls.apply_parameters(subtypes) + + # Benchmark cached path + n = 100_000 + t_cached = timeit.timeit(run_cached, number=n) + print(f"Cached apply_parameters ({len(calls)} types x {n} iters): " + f"{t_cached:.3f}s ({t_cached / (n * len(calls)) * 1e6:.2f} us/call)") + + # Benchmark uncached path (clear cache each iteration) + def run_uncached(): + for cls, subtypes in calls: + cache.clear() + cls.apply_parameters(subtypes) + + t_uncached = timeit.timeit(run_uncached, number=n) + print(f"Uncached apply_parameters ({len(calls)} types x {n} iters): " + f"{t_uncached:.3f}s ({t_uncached / (n * len(calls)) * 1e6:.2f} us/call)") + + speedup = t_uncached / t_cached + print(f"Speedup: {speedup:.1f}x") + + +def main(): + bench_apply_parameters() + + +if __name__ == '__main__': + main() diff --git a/cassandra/__init__.py b/cassandra/__init__.py index 5567c0b9bd..3ad8fcdfd1 100644 --- a/cassandra/__init__.py +++ b/cassandra/__init__.py @@ -23,7 +23,7 @@ def emit(self, record): logging.getLogger('cassandra').addHandler(NullHandler()) -__version_info__ = (3, 29, 8) +__version_info__ = (3, 29, 9) __version__ = '.'.join(map(str, __version_info__)) diff --git a/cassandra/client_routes.py b/cassandra/client_routes.py new file mode 100644 index 0000000000..80b2477a6d --- /dev/null +++ b/cassandra/client_routes.py @@ -0,0 +1,451 @@ +# Copyright 2026 ScyllaDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Client Routes support for Private Link and similar network configurations. + +This module implements support for dynamic address translation via the +system.client_routes table and CLIENT_ROUTES_CHANGE events. +""" + +from __future__ import absolute_import + +from dataclasses import dataclass +import enum +import logging +import socket +import threading +import uuid +from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple + +from cassandra import ConsistencyLevel +from cassandra.protocol import QueryMessage +from cassandra.query import dict_factory + +if TYPE_CHECKING: + from cassandra.connection import Connection + +log = logging.getLogger(__name__) + + +class ClientRoutesChangeType(enum.Enum): + """ + Types of CLIENT_ROUTES_CHANGE events. + + Currently the protocol defines only UPDATE_NODES. + New variants will be added here if the protocol is extended. + """ + UPDATE_NODES = "UPDATE_NODES" + + +@dataclass +class ClientRouteProxy: + """ + :param connection_id: String identifying the connection (required) + :param connection_addr_override:: Optional string address for initial connection + """ + + connection_id: str + connection_addr_override: Optional[str] = None + + def __post_init__(self): + if self.connection_id is None: + raise ValueError("connection_id is required") + +class ClientRoutesConfig: + """ + Configuration for client routes (Private Link support). + + :param proxies: List of :class:`ClientRouteProxy` objects + (REQUIRED, at least one) + :param advanced_shard_awareness: Whether to enable advanced shard awareness + (default: ``False``) + """ + + proxies: List[ClientRouteProxy] + advanced_shard_awareness: bool + + def __init__(self, proxies: List[ClientRouteProxy], advanced_shard_awareness: bool = False): + """ + :param proxies: List of ClientRouteProxy objects + :param advanced_shard_awareness: Enable advanced shard awareness (default False) + """ + if not proxies: + raise ValueError("At least one proxy must be specified") + + if not isinstance(proxies, (list, tuple)): + raise TypeError("proxies must be a list or tuple") + + for proxy in proxies: + if not isinstance(proxy, ClientRouteProxy): + raise TypeError("All proxies must be ClientRouteProxy instances") + + self.proxies = proxies + self.advanced_shard_awareness = advanced_shard_awareness + + def __repr__(self) -> str: + return (f"ClientRoutesConfig(proxies={self.proxies}, " + f"advanced_shard_awareness={self.advanced_shard_awareness})") + + +@dataclass(frozen=True) +class _Route: + connection_id: str + host_id: uuid.UUID + address: str # ipv4, ipv6 or DNS hostname from system.client_routes + port: int + +class _RouteStore: + """ + Thread-safe storage for routes. Reads are safe under CPython's GIL; + writes are serialized with a lock. + + This uses atomic pointer swaps for updates, allowing lock-free reads + while serializing writes. + """ + + _routes_by_host_id: Dict[uuid.UUID, _Route] + _lock: threading.Lock + + def __init__(self) -> None: + self._routes_by_host_id = {} + self._lock = threading.Lock() + + def get_by_host_id(self, host_id: uuid.UUID) -> Optional[_Route]: + """ + Get route for a host ID (lock-free read). + + :param host_id: UUID of the host + :return: _Route or None + """ + return self._routes_by_host_id.get(host_id) + + def get_all(self) -> List[_Route]: + """ + Get all routes as a list (lock-free read). + + :return: List of _Route + """ + return list(self._routes_by_host_id.values()) + + def _select_preferred_routes(self, new_routes: List[_Route]) -> List[_Route]: + """ + When multiple routes exist for the same host_id (different connection_ids), + prefer the connection_id already in use. Only migrate to a different + connection_id when the previously used one is no longer available. + + Must be called under self._lock. + """ + by_host: Dict[uuid.UUID, List[_Route]] = {} + for route in new_routes: + by_host.setdefault(route.host_id, []).append(route) + + selected = [] + for host_id, candidates in by_host.items(): + if len(candidates) == 1: + selected.append(candidates[0]) + continue + + existing = self._routes_by_host_id.get(host_id) + if existing: + preferred = [c for c in candidates if c.connection_id == existing.connection_id] + if preferred: + selected.append(preferred[0]) + continue + + selected.append(candidates[0]) + + return selected + + def update(self, routes: List[_Route]) -> None: + """ + Replace all routes atomically. + + :param routes: List of _Route objects + """ + with self._lock: + preferred = self._select_preferred_routes(routes) + self._routes_by_host_id = {route.host_id: route for route in preferred} + + def merge(self, new_routes: List[_Route], affected_host_ids: Set[uuid.UUID]) -> None: + """ + Merge new routes with existing ones atomically. + + Routes for affected_host_ids are replaced entirely: existing routes + for those hosts are dropped and replaced with whatever is in new_routes. + This handles deletions from system.client_routes (affected host present + but no new route for it). + + :param new_routes: List of _Route objects to merge + :param affected_host_ids: Set of host IDs affected by the change. + """ + with self._lock: + preferred = self._select_preferred_routes(new_routes) + new_by_host = {r.host_id: r for r in preferred} + + updated = {hid: r for hid, r in self._routes_by_host_id.items() + if hid not in affected_host_ids} + updated.update(new_by_host) + self._routes_by_host_id = updated + + +class _ClientRoutesHandler: + """ + Handles dynamic address translation for Private Link via system.client_routes. + + Lifecycle: + 1. Construction: Create with configuration + 2. Initialization: Read system.client_routes after control connection established + 3. Steady state: Listen for CLIENT_ROUTES_CHANGE events and update routes + 4. Translation: Translate addresses using Host ID lookup + """ + + config: 'ClientRoutesConfig' + ssl_enabled: bool + _routes: _RouteStore + _connection_ids: Set[str] + _proxy_addresses_override: Dict[str, str] + + def __init__(self, config: 'ClientRoutesConfig', ssl_enabled: bool = False): + """ + :param config: ClientRoutesConfig instance + :param ssl_enabled: Whether TLS is enabled (determines port selection) + """ + if not isinstance(config, ClientRoutesConfig): + raise TypeError("config must be a ClientRoutesConfig instance") + + self.config = config + self.ssl_enabled = ssl_enabled + self._routes = _RouteStore() + self._connection_ids = {dep.connection_id for dep in config.proxies} + # Precalculate proxy address mappings for efficient lookup + self._proxy_addresses_override = { + proxy.connection_id: proxy.connection_addr_override + for proxy in config.proxies + if proxy.connection_addr_override + } + + def initialize(self, connection: 'Connection', timeout: float) -> None: + """ + Load all routes from system.client_routes. + + Called once at startup and again whenever the control connection + is re-established. Reads all configured connection IDs and + replaces the in-memory route store atomically. + + Raises on failure so the caller can decide how to react (e.g. + abort startup or schedule a reconnect). + + :param connection: The Connection instance to execute queries on + :param timeout: Query timeout in seconds + """ + log.info("[client routes] Loading routes for %d proxies", len(self.config.proxies)) + + routes = self._query_all_routes_for_connections(connection, timeout, self._connection_ids) + self._routes.update(routes) + + def handle_client_routes_change(self, connection: 'Connection', timeout: float, + change_type: 'ClientRoutesChangeType', + connection_ids: Sequence[str], host_ids: Sequence[str]) -> None: + """ + Handle CLIENT_ROUTES_CHANGE event. + + Currently the protocol defines only :attr:`ClientRoutesChangeType.UPDATE_NODES`. + New variants will be added to the enum if the protocol is extended. + + :param connection: The Connection instance to execute queries on + :param timeout: Query timeout in seconds + :param change_type: A :class:`ClientRoutesChangeType` value + :param connection_ids: Affected connection ID strings; empty means all. + :param host_ids: Affected host ID strings; empty means all. + """ + + full_refresh = False + if not connection_ids or not host_ids: + log.warning( + "[client routes] CLIENT_ROUTES_CHANGE has no connection_ids or host_ids, doing full refresh") + full_refresh = True + elif len(connection_ids) != len(host_ids): + log.warning("[client routes] CLIENT_ROUTES_CHANGE has mismatched lengths (conn: %d, host: %d), doing full refresh", + len(connection_ids), len(host_ids)) + full_refresh = True + + if full_refresh: + routes = self._query_all_routes_for_connections(connection, timeout, self._connection_ids) + self._routes.update(routes) + return + + host_uuids = [uuid.UUID(hid) for hid in host_ids] + pairs = [(cid, hid) for cid, hid in zip(connection_ids, host_uuids) + if cid in self._connection_ids] + + if not pairs: + return + + routes = self._query_routes_for_change_event(connection, timeout, pairs) + self._routes.merge(routes, affected_host_ids=set(host_uuids)) + + def _query_all_routes_for_connections(self, connection: 'Connection', timeout: float, + connection_ids: Set[str]) -> List[_Route]: + """ + Query all routes for the given connection IDs (complete refresh). + + Used when control connection reconnects or as a fallback when + CLIENT_ROUTES_CHANGE event has malformed data. + + :param connection: Connection to execute query on + :param timeout: Query timeout in seconds + :param connection_ids: Set of connection ID strings + :return: List of _Route + """ + if not connection_ids: + return [] + + placeholders = ', '.join('?' for _ in connection_ids) + query = f"SELECT connection_id, host_id, address, port, tls_port FROM system.client_routes WHERE connection_id IN ({placeholders})" + params = [cid.encode('utf-8') for cid in connection_ids] + + log.debug("[client routes] Querying all routes for connection_ids=%s", connection_ids) + return self._execute_routes_query(connection, timeout, query, params) + + def _query_routes_for_change_event(self, connection: 'Connection', timeout: float, + route_pairs: List[Tuple[str, uuid.UUID]]) -> List[_Route]: + """ + Query specific routes affected by a CLIENT_ROUTES_CHANGE event. + + Takes a list of (connection_id, host_id) pairs that represent the exact + routes affected by an operation. This provides precise updates without + fetching unrelated routes. + + If the pairs list is empty or None, falls back to a complete refresh + of all routes for safety. + + :param connection: Connection to execute query on + :param timeout: Query timeout in seconds + :param route_pairs: List of (connection_id, host_id) tuples + :return: List of _Route + """ + unique_pairs = list(dict.fromkeys(route_pairs)) + + conn_ids = list(dict.fromkeys(cid for cid, _ in unique_pairs)) + host_ids = list(dict.fromkeys(hid for _, hid in unique_pairs)) + + log.debug("[client routes] Querying route pairs from CLIENT_ROUTES_CHANGE " + "(first 5 of %d): %s", len(unique_pairs), unique_pairs[:5]) + + conn_ph = ', '.join('?' for _ in conn_ids) + host_ph = ', '.join('?' for _ in host_ids) + query = ( + "SELECT connection_id, host_id, address, port, tls_port " + "FROM system.client_routes " + f"WHERE connection_id IN ({conn_ph}) AND host_id IN ({host_ph})" + ) + params: List = [cid.encode('utf-8') for cid in conn_ids] + params.extend(hid.bytes for hid in host_ids) + + return self._execute_routes_query(connection, timeout, query, params) + + def _execute_routes_query(self, connection: 'Connection', timeout: float, + query: str, params: List) -> List[_Route]: + """ + Execute a routes query and parse results. + + Common helper for both complete refresh and change event queries. + + :param connection: Connection to execute query on + :param timeout: Query timeout in seconds + :param query: CQL query string + :param params: Query parameters + :return: List of _Route + """ + log.debug("[client routes] Executing query: %s with %d parameters", query, len(params)) + + query_msg = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE, + query_params=params if params else None) + result = connection.wait_for_response( + query_msg, timeout=timeout + ) + + routes = [] + broken = 0 + rows = dict_factory(result.column_names, result.parsed_rows) + for row in rows: + try: + absent = [] + port = row['tls_port'] if self.ssl_enabled else row['port'] + connection_id = row['connection_id'] + host_id = row['host_id'] + address = row['address'] + + if not port: + absent.append("tls_port" if self.ssl_enabled else "port") + if not connection_id: + absent.append("connection_id") + if not host_id: + absent.append("host_id") + if not address: + absent.append("address") + + if absent: + log.error("[client routes] read a route %s, that has no values for the following fields: %s", row, ",".join(absent)) + broken += 1 + continue + + final_address = self._proxy_addresses_override.get(connection_id, address) + + routes.append(_Route( + connection_id=connection_id, + host_id=host_id, + address=final_address, + port=port, + )) + except Exception as e: + log.warning("[client routes] Failed to parse route row: %s", e) + broken += 1 + + if broken and not routes: + raise RuntimeError( + "[client routes] All %d route rows failed validation; " + "refusing to return empty result that would wipe the route store" % broken + ) + + return routes + + def resolve_host(self, host_id: uuid.UUID) -> Optional[Tuple[str, int]]: + """ + Resolve a host_id to an (address, port) pair. + + Looks up the current route and selects the appropriate port. + + :param host_id: Host UUID to resolve + :return: Tuple of (address, port) or None if no route mapping exists + """ + route = self._routes.get_by_host_id(host_id) + if route is None: + return None + + if not route.port: + raise ValueError("Mapping for host %s has no port" % host_id) + + try: + result = socket.getaddrinfo(route.address, route.port, + socket.AF_UNSPEC, socket.SOCK_STREAM) + if not result: + raise socket.gaierror("No addresses found for %s" % route.address) + resolved_ip = result[0][4][0] + return resolved_ip, route.port + except socket.gaierror as e: + log.warning('[client routes] Could not resolve hostname "%s" (host_id=%s): %s', + route.address, host_id, e) + raise diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 51d0b2d88b..9eace8810d 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -29,7 +29,7 @@ from itertools import groupby, count, chain import json import logging -from typing import Optional, Union +from typing import Any, Dict, Optional, Union from warnings import warn from random import random import re @@ -48,7 +48,8 @@ SchemaTargetType, DriverException, ProtocolVersion, UnresolvableContactPoints, DependencyException) from cassandra.auth import _proxy_execute_key, PlainTextAuthProvider -from cassandra.connection import (ConnectionException, ConnectionShutdown, +from cassandra.client_routes import ClientRoutesChangeType, ClientRoutesConfig, _ClientRoutesHandler +from cassandra.connection import (ClientRoutesEndPointFactory, ConnectionException, ConnectionShutdown, ConnectionHeartbeat, ProtocolVersionUnsupported, EndPoint, DefaultEndPoint, DefaultEndPointFactory, SniEndPointFactory, ConnectionBusy, locally_supported_compressions) @@ -1215,7 +1216,8 @@ def __init__(self, shard_aware_options=None, metadata_request_timeout: Optional[float] = None, column_encryption_policy=None, - application_info:Optional[ApplicationInfoBase]=None + application_info:Optional[ApplicationInfoBase]=None, + client_routes_config:Optional[ClientRoutesConfig]=None ): """ ``executor_threads`` defines the number of threads in a pool for handling asynchronous tasks such as @@ -1280,6 +1282,45 @@ def __init__(self, if column_encryption_policy is not None: self.column_encryption_policy = column_encryption_policy + if client_routes_config is not None and endpoint_factory is not None: + raise ValueError("client_routes_config and endpoint_factory are mutually exclusive") + + self._client_routes_handler = None + if client_routes_config is not None: + if not isinstance(client_routes_config, ClientRoutesConfig): + raise TypeError("client_routes_config must be a ClientRoutesConfig instance") + + # SSL hostname verification is incompatible with client routes: + # connections go through NLB proxies whose addresses won't match + # server certificates. + _check_hostname_enabled = False + if ssl_context is not None and ssl_context.check_hostname: + _check_hostname_enabled = True + if ssl_options is not None and ssl_options.get('check_hostname', False): + _check_hostname_enabled = True + if _check_hostname_enabled: + raise ValueError( + "SSL hostname verification (check_hostname=True) is currently incompatible " + "with client_routes_config. When using client routes, connections " + "go through NLB proxies whose addresses won't match server " + "certificates. Disable hostname verification by setting " + "ssl_context.check_hostname = False." + ) + + ssl_enabled = ssl_context is not None or ssl_options is not None + self._client_routes_handler = _ClientRoutesHandler(client_routes_config, ssl_enabled=ssl_enabled) + + if contact_points is _NOT_SET or not self._contact_points_explicit: + seed_addrs = [dep.connection_addr_override for dep in client_routes_config.proxies + if dep.connection_addr_override] + if seed_addrs: + self.contact_points = seed_addrs + self._contact_points_explicit = True + log.info("[client routes] Using %d deployment connection addresses as contact points", + len(seed_addrs)) + + if self._client_routes_handler is not None: + endpoint_factory = ClientRoutesEndPointFactory(self._client_routes_handler, self.port) self.endpoint_factory = endpoint_factory or DefaultEndPointFactory(port=self.port) self.endpoint_factory.configure(self) @@ -1437,6 +1478,10 @@ def __init__(self, self.monitor_reporting_interval = monitor_reporting_interval self.shard_aware_options = ShardAwareOptions(opts=shard_aware_options) + if (client_routes_config is not None + and not client_routes_config.advanced_shard_awareness): + self.shard_aware_options.disable_shardaware_port = True + self._listeners = set() self._listener_lock = Lock() @@ -3612,11 +3657,21 @@ def _try_connect(self, endpoint): # this object (after a dereferencing a weakref) self_weakref = weakref.ref(self, partial(_clear_watcher, weakref.proxy(connection))) try: - connection.register_watchers({ + watchers = { "TOPOLOGY_CHANGE": partial(_watch_callback, self_weakref, '_handle_topology_change'), "STATUS_CHANGE": partial(_watch_callback, self_weakref, '_handle_status_change'), "SCHEMA_CHANGE": partial(_watch_callback, self_weakref, '_handle_schema_change') - }, register_timeout=self._timeout) + } + + if self._cluster._client_routes_handler is not None: + watchers["CLIENT_ROUTES_CHANGE"] = partial(_watch_callback, self_weakref, '_handle_client_routes_change') + + connection.register_watchers(watchers, register_timeout=self._timeout) + + if self._cluster._client_routes_handler is not None: + self._cluster._client_routes_handler.initialize( + connection, + self._timeout) sel_peers = self._get_peers_query(self.PeersQueryType.PEERS, connection) sel_local = self._SELECT_LOCAL if self._token_meta_enabled else self._SELECT_LOCAL_NO_TOKENS @@ -3979,6 +4034,44 @@ def _handle_status_change(self, event): # this will be run by the scheduler self._cluster.on_down(host, is_host_addition=False) + def _handle_client_routes_change(self, event: Dict[str, Any]) -> None: + """ + Handle CLIENT_ROUTES_CHANGE event from the server. + + This event indicates that the system.client_routes table has been updated + and we need to refresh our route mappings. + """ + if self._cluster._client_routes_handler is None: + log.warning("[control connection] Received CLIENT_ROUTES_CHANGE but no handler configured") + return + + raw_change_type = event.get("change_type") + try: + change_type = ClientRoutesChangeType(raw_change_type) + except ValueError: + log.warning("[control connection] Unknown CLIENT_ROUTES_CHANGE type: %s", raw_change_type) + return + + connection_ids = tuple(event.get("connection_ids", [])) + host_ids = tuple(event.get("host_ids", [])) + + self._cluster.scheduler.schedule_unique( + 0, + self._handle_client_routes_refresh, + self._connection, self._timeout, change_type, connection_ids, host_ids + ) + + def _handle_client_routes_refresh(self, connection, timeout, + change_type, connection_ids, host_ids): + try: + self._cluster._client_routes_handler.handle_client_routes_change( + connection, timeout, change_type, connection_ids, host_ids) + except ReferenceError: + pass # our weak reference to the Cluster is no good + except Exception: + log.debug("[control connection] Error handling CLIENT_ROUTES_CHANGE", exc_info=True) + self._signal_error() + def _handle_schema_change(self, event): if self._schema_event_refresh_window < 0: return @@ -4024,7 +4117,8 @@ def wait_for_schema_agreement(self, connection=None, preloaded_results=None, wai local_query = QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_SCHEMA_LOCAL, self._metadata_request_timeout), consistency_level=cl) try: - timeout = min(self._timeout, total_timeout - elapsed) + remaining = total_timeout - elapsed + timeout = min(self._timeout, remaining) if self._timeout is not None else remaining peers_result, local_result = connection.wait_for_responses( peers_query, local_query, timeout=timeout) except OperationTimedOut as timeout: diff --git a/cassandra/connection.py b/cassandra/connection.py index 87f860f32b..72b273ec37 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -25,12 +25,14 @@ from threading import Thread, Event, RLock, Condition import time import ssl +import uuid import weakref import random import itertools -from typing import Optional, Union +from typing import Any, Dict, Optional, Tuple, Union from cassandra.application_info import ApplicationInfoBase +from cassandra.client_routes import _ClientRoutesHandler from cassandra.protocol_features import ProtocolFeatures if 'gevent.monkey' in sys.modules: @@ -230,7 +232,7 @@ class DefaultEndPointFactory(EndPointFactory): port = None """ If no port is discovered in the row, this is the default port - used for endpoint creation. + used for endpoint creation. """ def __init__(self, port=None): @@ -328,6 +330,50 @@ def create_from_sni(self, sni): return SniEndPoint(self._proxy_address, sni, self._port) +class ClientRoutesEndPointFactory(EndPointFactory): + """ + EndPointFactory for Client Routes (Private Link) support. + + Creates ClientRoutesEndPoint instances that defer both address translation + (host_id -> hostname lookup) and DNS resolution until connection time. + This ensures immediate reaction to infrastructure changes. + """ + + client_routes_handler: _ClientRoutesHandler + default_port: int + + def __init__(self, client_routes_handler: _ClientRoutesHandler, default_port: int = None) -> None: + """ + :param client_routes_handler: _ClientRoutesHandler instance to lookup routes + :param default_port: Default port if none found in row + """ + self.client_routes_handler = client_routes_handler + self.default_port = default_port + + def create(self, row: Dict[str, Any]) -> 'ClientRoutesEndPoint': + """ + Create a ClientRoutesEndPoint from a system.peers row. + + Stores only the host_id and handler reference. Both translation + (route lookup) and DNS resolution happen later in resolve(). + """ + from cassandra.metadata import _NodeInfo + host_id = row.get("host_id") + + if host_id is None: + raise ValueError("No host_id to create ClientRoutesEndPoint") + + addr = _NodeInfo.get_broadcast_rpc_address(row) + port = _NodeInfo.get_broadcast_rpc_port(row) or _NodeInfo.get_broadcast_port(row) or self.default_port + + return ClientRoutesEndPoint( + host_id=host_id, + handler=self.client_routes_handler, + original_address=addr, + original_port=port, + ) + + @total_ordering class UnixSocketEndPoint(EndPoint): """ @@ -369,6 +415,76 @@ def __repr__(self): return "<%s: %s>" % (self.__class__.__name__, self._unix_socket_path) +@total_ordering +class ClientRoutesEndPoint(EndPoint): + """ + Client Routes (Private Link) EndPoint implementation. + + Defers both address translation (route lookup) and DNS resolution + until resolve() is called at connection time. This ensures immediate + reaction to infrastructure changes and CLIENT_ROUTES_CHANGE events. + """ + + _host_id: uuid.UUID + _handler: _ClientRoutesHandler + _original_address: str + _original_port: int + + def __init__(self, host_id: uuid.UUID, handler: _ClientRoutesHandler, original_address: str, original_port: int = None) -> None: + """ + :param host_id: Host UUID for route lookup + :param handler: _ClientRoutesHandler instance + :param original_address: Original address from system.peers (for identification) + :param original_port: Original port if route doesn't specify one + """ + self._host_id = host_id + self._handler = handler + self._original_address = original_address + self._original_port = original_port + + @property + def address(self) -> str: + """Returns the original address (updated by resolve()).""" + return self._original_address + + @property + def port(self) -> Optional[int]: + return self._original_port + + @property + def host_id(self) -> uuid.UUID: + return self._host_id + + def resolve(self) -> Tuple[str, int]: + """ + Resolve endpoint by delegating to the handler. + Falls back to original address/port if no route mapping is available. + """ + result = self._handler.resolve_host(self._host_id) + if result is None: + return self._original_address, self._original_port + return result + + def __eq__(self, other): + return (isinstance(other, ClientRoutesEndPoint) and + self._host_id == other._host_id and + self._original_address == other._original_address) + + def __hash__(self): + return hash((self._host_id, self._original_address)) + + def __lt__(self, other): + return ((self._host_id, self._original_address) < + (other._host_id, other._original_address)) + + def __str__(self): + return str("%s (host_id=%s)" % (self._original_address, self._host_id)) + + def __repr__(self): + return "<%s: host_id=%s, original_addr=%s>" % ( + self.__class__.__name__, self._host_id, self._original_address) + + class _Frame(object): def __init__(self, version, flags, stream, opcode, body_offset, end_pos): self.version = version diff --git a/cassandra/cqltypes.py b/cassandra/cqltypes.py index d33e5fceb8..11ab694cb3 100644 --- a/cassandra/cqltypes.py +++ b/cassandra/cqltypes.py @@ -249,6 +249,8 @@ def lookup_casstype(casstype): """ if isinstance(casstype, (CassandraType, CassandraTypeType)): return casstype + if '(' not in casstype: + return lookup_casstype_simple(casstype) try: return parse_casstype_args(casstype) except (ValueError, AssertionError, IndexError) as e: @@ -273,6 +275,7 @@ class _CassandraType(object, metaclass=CassandraTypeType): subtypes = () num_subtypes = 0 empty_binary_ok = False + _apply_parameters_cache = {} support_empty_values = False """ @@ -371,8 +374,15 @@ def apply_parameters(cls, subtypes, names=None): if cls.num_subtypes != 'UNKNOWN' and len(subtypes) != cls.num_subtypes: raise ValueError("%s types require %d subtypes (%d given)" % (cls.typename, cls.num_subtypes, len(subtypes))) + subtypes = tuple(subtypes) + cache_key = (cls, subtypes, tuple(names) if names else names) + cached = cls._apply_parameters_cache.get(cache_key) + if cached is not None: + return cached newname = cls.cass_parameterized_type_with(subtypes) - return type(newname, (cls,), {'subtypes': subtypes, 'cassname': cls.cassname, 'fieldnames': names}) + result = type(newname, (cls,), {'subtypes': subtypes, 'cassname': cls.cassname, 'fieldnames': names}) + cls._apply_parameters_cache[cache_key] = result + return result @classmethod def cql_parameterized_type(cls): diff --git a/cassandra/metadata.py b/cassandra/metadata.py index b85308449e..43399b7152 100644 --- a/cassandra/metadata.py +++ b/cassandra/metadata.py @@ -2577,6 +2577,10 @@ class SchemaParserV3(SchemaParserV22): _SELECT_AGGREGATES = "SELECT * FROM system_schema.aggregates" _SELECT_VIEWS = "SELECT * FROM system_schema.views" + def _is_not_scylla(self): + """Check if NOT connected to ScyllaDB by checking for shard awareness.""" + return getattr(getattr(self.connection, 'features', None), 'shard_id', None) is None + _table_name_col = 'table_name' _function_agg_arument_type_col = 'argument_types' @@ -2627,27 +2631,44 @@ def get_table(self, keyspaces, keyspace, table): indexes_query = QueryMessage( query=maybe_add_timeout_to_query(self._SELECT_INDEXES + where_clause, self.metadata_request_timeout), consistency_level=cl, fetch_size=fetch_size) - triggers_query = QueryMessage( - query=maybe_add_timeout_to_query(self._SELECT_TRIGGERS + where_clause, self.metadata_request_timeout), - consistency_level=cl, fetch_size=fetch_size) + + # ScyllaDB doesn't have triggers, skip the query + if self._is_not_scylla(): + triggers_query = QueryMessage( + query=maybe_add_timeout_to_query(self._SELECT_TRIGGERS + where_clause, self.metadata_request_timeout), + consistency_level=cl, fetch_size=fetch_size) # in protocol v4 we don't know if this event is a view or a table, so we look for both where_clause = bind_params(" WHERE keyspace_name = %s AND view_name = %s", (keyspace, table), _encoder) view_query = QueryMessage( query=maybe_add_timeout_to_query(self._SELECT_VIEWS + where_clause, self.metadata_request_timeout), consistency_level=cl, fetch_size=fetch_size) - ((cf_success, cf_result), (col_success, col_result), - (indexes_sucess, indexes_result), (triggers_success, triggers_result), - (view_success, view_result)) = ( - self.connection.wait_for_responses( - cf_query, col_query, indexes_query, triggers_query, - view_query, timeout=self.timeout, fail_on_error=False) - ) + + if self._is_not_scylla(): + ((cf_success, cf_result), (col_success, col_result), + (indexes_sucess, indexes_result), (triggers_success, triggers_result), + (view_success, view_result)) = ( + self.connection.wait_for_responses( + cf_query, col_query, indexes_query, triggers_query, + view_query, timeout=self.timeout, fail_on_error=False) + ) + else: + ((cf_success, cf_result), (col_success, col_result), + (indexes_sucess, indexes_result), + (view_success, view_result)) = ( + self.connection.wait_for_responses( + cf_query, col_query, indexes_query, + view_query, timeout=self.timeout, fail_on_error=False) + ) + table_result = self._handle_results(cf_success, cf_result, query_msg=cf_query) col_result = self._handle_results(col_success, col_result, query_msg=col_query) if table_result: indexes_result = self._handle_results(indexes_sucess, indexes_result, query_msg=indexes_query) - triggers_result = self._handle_results(triggers_success, triggers_result, query_msg=triggers_query) + if self._is_not_scylla(): + triggers_result = self._handle_results(triggers_success, triggers_result, query_msg=triggers_query) + else: + triggers_result = None return self._build_table_metadata(table_result[0], col_result, triggers_result, indexes_result) view_result = self._handle_results(view_success, view_result, query_msg=view_query) @@ -2696,9 +2717,10 @@ def _build_table_metadata(self, row, col_rows=None, trigger_rows=None, index_row self._build_table_columns(table_meta, col_rows, compact_static, is_dense, virtual) - for trigger_row in trigger_rows: - trigger_meta = self._build_trigger_metadata(table_meta, trigger_row) - table_meta.triggers[trigger_meta.name] = trigger_meta + if self._is_not_scylla(): + for trigger_row in trigger_rows: + trigger_meta = self._build_trigger_metadata(table_meta, trigger_row) + table_meta.triggers[trigger_meta.name] = trigger_meta for index_row in index_rows: index_meta = self._build_index_metadata(table_meta, index_row) @@ -2741,7 +2763,7 @@ def _build_table_columns(self, meta, col_rows, compact_static=False, is_dense=Fa meta.clustering_key.append(meta.columns[r.get('column_name')]) for col_row in (r for r in col_rows - if r.get('kind', None) not in ('partition_key', 'clustering_key')): + if r.get('kind', None) not in ('partition_key', 'clustering')): column_meta = self._build_column_metadata(meta, col_row) if is_dense and column_meta.cql_type == types.cql_empty_type: continue @@ -2793,6 +2815,7 @@ def _build_trigger_metadata(table_metadata, row): trigger_meta = TriggerMetadata(table_metadata, name, options) return trigger_meta + def _query_all(self): cl = ConsistencyLevel.ONE fetch_size = self.fetch_size @@ -2809,35 +2832,45 @@ def _query_all(self): fetch_size=fetch_size, consistency_level=cl), QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_AGGREGATES, self.metadata_request_timeout), fetch_size=fetch_size, consistency_level=cl), - QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_TRIGGERS, self.metadata_request_timeout), - fetch_size=fetch_size, consistency_level=cl), QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_INDEXES, self.metadata_request_timeout), fetch_size=fetch_size, consistency_level=cl), QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_VIEWS, self.metadata_request_timeout), fetch_size=fetch_size, consistency_level=cl), ] + # ScyllaDB doesn't have triggers, skip the query + if self._is_not_scylla(): + queries.append(QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_TRIGGERS, self.metadata_request_timeout), + fetch_size=fetch_size, consistency_level=cl)) + + responses = self.connection.wait_for_responses(*queries, timeout=self.timeout, fail_on_error=False) + + # Unpack common responses (always present) ((ks_success, ks_result), (table_success, table_result), (col_success, col_result), (types_success, types_result), (functions_success, functions_result), (aggregates_success, aggregates_result), - (triggers_success, triggers_result), (indexes_success, indexes_result), - (views_success, views_result)) = self.connection.wait_for_responses( - *queries, timeout=self.timeout, fail_on_error=False - ) + (views_success, views_result)) = responses[:8] + + # Unpack triggers response if present (Cassandra/DSE only) + if self._is_not_scylla(): + (triggers_success, triggers_result) = responses[8] self.keyspaces_result = self._handle_results(ks_success, ks_result, query_msg=queries[0]) self.tables_result = self._handle_results(table_success, table_result, query_msg=queries[1]) self.columns_result = self._handle_results(col_success, col_result, query_msg=queries[2]) - self.triggers_result = self._handle_results(triggers_success, triggers_result, query_msg=queries[6]) self.types_result = self._handle_results(types_success, types_result, query_msg=queries[3]) self.functions_result = self._handle_results(functions_success, functions_result, query_msg=queries[4]) self.aggregates_result = self._handle_results(aggregates_success, aggregates_result, query_msg=queries[5]) - self.indexes_result = self._handle_results(indexes_success, indexes_result, query_msg=queries[7]) - self.views_result = self._handle_results(views_success, views_result, query_msg=queries[8]) + self.indexes_result = self._handle_results(indexes_success, indexes_result, query_msg=queries[6]) + self.views_result = self._handle_results(views_success, views_result, query_msg=queries[7]) + if self._is_not_scylla(): + self.triggers_result = self._handle_results(triggers_success, triggers_result, query_msg=queries[8]) + else: + self.triggers_result = [] self._aggregate_results() @@ -2915,8 +2948,6 @@ def _query_all(self): fetch_size=fetch_size, consistency_level=cl), QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_AGGREGATES, self.metadata_request_timeout), fetch_size=fetch_size, consistency_level=cl), - QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_TRIGGERS, self.metadata_request_timeout), - fetch_size=fetch_size, consistency_level=cl), QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_INDEXES, self.metadata_request_timeout), fetch_size=fetch_size, consistency_level=cl), QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_VIEWS, self.metadata_request_timeout), @@ -2930,8 +2961,15 @@ def _query_all(self): fetch_size=fetch_size, consistency_level=cl), ] + # ScyllaDB doesn't have triggers, skip the query + if self._is_not_scylla(): + queries.append(QueryMessage(query=maybe_add_timeout_to_query(self._SELECT_TRIGGERS, self.metadata_request_timeout), + fetch_size=fetch_size, consistency_level=cl)) + responses = self.connection.wait_for_responses( *queries, timeout=self.timeout, fail_on_error=False) + + # Unpack common responses (always present) ( # copied from V3 (ks_success, ks_result), @@ -2940,39 +2978,45 @@ def _query_all(self): (types_success, types_result), (functions_success, functions_result), (aggregates_success, aggregates_result), - (triggers_success, triggers_result), (indexes_success, indexes_result), (views_success, views_result), # V4-only responses (virtual_ks_success, virtual_ks_result), (virtual_table_success, virtual_table_result), - (virtual_column_success, virtual_column_result) - ) = responses + (virtual_column_success, virtual_column_result), + ) = responses[:11] + + # Unpack triggers response if present (Cassandra/DSE only) + if self._is_not_scylla(): + (triggers_success, triggers_result) = responses[11] # copied from V3 self.keyspaces_result = self._handle_results(ks_success, ks_result, query_msg=queries[0]) self.tables_result = self._handle_results(table_success, table_result, query_msg=queries[1]) self.columns_result = self._handle_results(col_success, col_result, query_msg=queries[2]) - self.triggers_result = self._handle_results(triggers_success, triggers_result, query_msg=queries[6]) self.types_result = self._handle_results(types_success, types_result, query_msg=queries[3]) self.functions_result = self._handle_results(functions_success, functions_result, query_msg=queries[4]) self.aggregates_result = self._handle_results(aggregates_success, aggregates_result, query_msg=queries[5]) - self.indexes_result = self._handle_results(indexes_success, indexes_result, query_msg=queries[7]) - self.views_result = self._handle_results(views_success, views_result, query_msg=queries[8]) + self.indexes_result = self._handle_results(indexes_success, indexes_result, query_msg=queries[6]) + self.views_result = self._handle_results(views_success, views_result, query_msg=queries[7]) + if self._is_not_scylla(): + self.triggers_result = self._handle_results(triggers_success, triggers_result, query_msg=queries[11]) + else: + self.triggers_result = [] # V4-only results # These tables don't exist in some DSE versions reporting 4.X so we can # ignore them if we got an error self.virtual_keyspaces_result = self._handle_results( virtual_ks_success, virtual_ks_result, - expected_failures=(InvalidRequest,), query_msg=queries[9] + expected_failures=(InvalidRequest,), query_msg=queries[8] ) self.virtual_tables_result = self._handle_results( virtual_table_success, virtual_table_result, - expected_failures=(InvalidRequest,), query_msg=queries[10] + expected_failures=(InvalidRequest,), query_msg=queries[9] ) self.virtual_columns_result = self._handle_results( virtual_column_success, virtual_column_result, - expected_failures=(InvalidRequest,), query_msg=queries[11] + expected_failures=(InvalidRequest,), query_msg=queries[10] ) self._aggregate_results() diff --git a/cassandra/protocol.py b/cassandra/protocol.py index f37633a756..4628c7ee0e 100644 --- a/cassandra/protocol.py +++ b/cassandra/protocol.py @@ -611,9 +611,10 @@ class QueryMessage(_QueryMessage): name = 'QUERY' def __init__(self, query, consistency_level, serial_consistency_level=None, - fetch_size=None, paging_state=None, timestamp=None, continuous_paging_options=None, keyspace=None): + fetch_size=None, paging_state=None, timestamp=None, continuous_paging_options=None, keyspace=None, + query_params=None): self.query = query - super(QueryMessage, self).__init__(None, consistency_level, serial_consistency_level, fetch_size, + super(QueryMessage, self).__init__(query_params, consistency_level, serial_consistency_level, fetch_size, paging_state, timestamp, False, continuous_paging_options, keyspace) def send_body(self, f, protocol_version): diff --git a/docs/conf.py b/docs/conf.py index 403908c29e..4b6b329525 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -29,10 +29,11 @@ '3.29.6-scylla', '3.29.7-scylla', '3.29.8-scylla', + '3.29.9-scylla', ] BRANCHES = ['master'] # Set the latest version. -LATEST_VERSION = '3.29.8-scylla' +LATEST_VERSION = '3.29.9-scylla' # Set which versions are not released yet. UNSTABLE_VERSIONS = ['master'] # Set which versions are deprecated diff --git a/docs/installation.rst b/docs/installation.rst index 4207c46092..7b4823b832 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -26,7 +26,7 @@ To check if the installation was successful, you can run:: python -c 'import cassandra; print(cassandra.__version__)' -It should print something like "3.29.8". +It should print something like "3.29.9". (*Optional*) Compression Support -------------------------------- @@ -199,7 +199,7 @@ through `Homebrew `_. For example, on Mac OS X:: $ brew install libev -The libev extension can now be built for Windows as of Python driver version 3.29.8. You can +The libev extension can now be built for Windows as of Python driver version 3.29.9. You can install libev using any Windows package manager. For example, to install using `vcpkg `_: $ vcpkg install libev diff --git a/docs/uv.lock b/docs/uv.lock index d6b5359d21..720a2080e7 100644 --- a/docs/uv.lock +++ b/docs/uv.lock @@ -614,11 +614,11 @@ wheels = [ [[package]] name = "pygments" -version = "2.19.2" +version = "2.20.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/b0/77/a5b8c569bf593b0140bde72ea885a803b82086995367bf2037de0159d924/pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887", size = 4968631, upload-time = "2025-06-21T13:39:12.283Z" } +sdist = { url = "https://files.pythonhosted.org/packages/c3/b2/bc9c9196916376152d655522fdcebac55e66de6603a76a02bca1b6414f6c/pygments-2.20.0.tar.gz", hash = "sha256:6757cd03768053ff99f3039c1a36d6c0aa0b263438fcab17520b30a303a82b5f", size = 4955991, upload-time = "2026-03-29T13:29:33.898Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" }, + { url = "https://files.pythonhosted.org/packages/f4/7e/a72dd26f3b0f4f2bf1dd8923c85f7ceb43172af56d63c7383eb62b332364/pygments-2.20.0-py3-none-any.whl", hash = "sha256:81a9e26dd42fd28a23a2d169d86d7ac03b46e2f8b59ed4698fb4785f946d0176", size = 1231151, upload-time = "2026-03-29T13:29:30.038Z" }, ] [[package]] @@ -1040,21 +1040,19 @@ wheels = [ [[package]] name = "tornado" -version = "6.5.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/09/ce/1eb500eae19f4648281bb2186927bb062d2438c2e5093d1360391afd2f90/tornado-6.5.2.tar.gz", hash = "sha256:ab53c8f9a0fa351e2c0741284e06c7a45da86afb544133201c5cc8578eb076a0", size = 510821, upload-time = "2025-08-08T18:27:00.78Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/f6/48/6a7529df2c9cc12efd2e8f5dd219516184d703b34c06786809670df5b3bd/tornado-6.5.2-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:2436822940d37cde62771cff8774f4f00b3c8024fe482e16ca8387b8a2724db6", size = 442563, upload-time = "2025-08-08T18:26:42.945Z" }, - { url = "https://files.pythonhosted.org/packages/f2/b5/9b575a0ed3e50b00c40b08cbce82eb618229091d09f6d14bce80fc01cb0b/tornado-6.5.2-cp39-abi3-macosx_10_9_x86_64.whl", hash = "sha256:583a52c7aa94ee046854ba81d9ebb6c81ec0fd30386d96f7640c96dad45a03ef", size = 440729, upload-time = "2025-08-08T18:26:44.473Z" }, - { url = "https://files.pythonhosted.org/packages/1b/4e/619174f52b120efcf23633c817fd3fed867c30bff785e2cd5a53a70e483c/tornado-6.5.2-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b0fe179f28d597deab2842b86ed4060deec7388f1fd9c1b4a41adf8af058907e", size = 444295, upload-time = "2025-08-08T18:26:46.021Z" }, - { url = "https://files.pythonhosted.org/packages/95/fa/87b41709552bbd393c85dd18e4e3499dcd8983f66e7972926db8d96aa065/tornado-6.5.2-cp39-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b186e85d1e3536d69583d2298423744740986018e393d0321df7340e71898882", size = 443644, upload-time = "2025-08-08T18:26:47.625Z" }, - { url = "https://files.pythonhosted.org/packages/f9/41/fb15f06e33d7430ca89420283a8762a4e6b8025b800ea51796ab5e6d9559/tornado-6.5.2-cp39-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e792706668c87709709c18b353da1f7662317b563ff69f00bab83595940c7108", size = 443878, upload-time = "2025-08-08T18:26:50.599Z" }, - { url = "https://files.pythonhosted.org/packages/11/92/fe6d57da897776ad2e01e279170ea8ae726755b045fe5ac73b75357a5a3f/tornado-6.5.2-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:06ceb1300fd70cb20e43b1ad8aaee0266e69e7ced38fa910ad2e03285009ce7c", size = 444549, upload-time = "2025-08-08T18:26:51.864Z" }, - { url = "https://files.pythonhosted.org/packages/9b/02/c8f4f6c9204526daf3d760f4aa555a7a33ad0e60843eac025ccfd6ff4a93/tornado-6.5.2-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:74db443e0f5251be86cbf37929f84d8c20c27a355dd452a5cfa2aada0d001ec4", size = 443973, upload-time = "2025-08-08T18:26:53.625Z" }, - { url = "https://files.pythonhosted.org/packages/ae/2d/f5f5707b655ce2317190183868cd0f6822a1121b4baeae509ceb9590d0bd/tornado-6.5.2-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:b5e735ab2889d7ed33b32a459cac490eda71a1ba6857b0118de476ab6c366c04", size = 443954, upload-time = "2025-08-08T18:26:55.072Z" }, - { url = "https://files.pythonhosted.org/packages/e8/59/593bd0f40f7355806bf6573b47b8c22f8e1374c9b6fd03114bd6b7a3dcfd/tornado-6.5.2-cp39-abi3-win32.whl", hash = "sha256:c6f29e94d9b37a95013bb669616352ddb82e3bfe8326fccee50583caebc8a5f0", size = 445023, upload-time = "2025-08-08T18:26:56.677Z" }, - { url = "https://files.pythonhosted.org/packages/c7/2a/f609b420c2f564a748a2d80ebfb2ee02a73ca80223af712fca591386cafb/tornado-6.5.2-cp39-abi3-win_amd64.whl", hash = "sha256:e56a5af51cc30dd2cae649429af65ca2f6571da29504a07995175df14c18f35f", size = 445427, upload-time = "2025-08-08T18:26:57.91Z" }, - { url = "https://files.pythonhosted.org/packages/5e/4f/e1f65e8f8c76d73658b33d33b81eed4322fb5085350e4328d5c956f0c8f9/tornado-6.5.2-cp39-abi3-win_arm64.whl", hash = "sha256:d6c33dc3672e3a1f3618eb63b7ef4683a7688e7b9e6e8f0d9aa5726360a004af", size = 444456, upload-time = "2025-08-08T18:26:59.207Z" }, +version = "6.5.5" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f8/f1/3173dfa4a18db4a9b03e5d55325559dab51ee653763bb8745a75af491286/tornado-6.5.5.tar.gz", hash = "sha256:192b8f3ea91bd7f1f50c06955416ed76c6b72f96779b962f07f911b91e8d30e9", size = 516006, upload-time = "2026-03-10T21:31:02.067Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/59/8c/77f5097695f4dd8255ecbd08b2a1ed8ba8b953d337804dd7080f199e12bf/tornado-6.5.5-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:487dc9cc380e29f58c7ab88f9e27cdeef04b2140862e5076a66fb6bb68bb1bfa", size = 445983, upload-time = "2026-03-10T21:30:44.28Z" }, + { url = "https://files.pythonhosted.org/packages/ab/5e/7625b76cd10f98f1516c36ce0346de62061156352353ef2da44e5c21523c/tornado-6.5.5-cp39-abi3-macosx_10_9_x86_64.whl", hash = "sha256:65a7f1d46d4bb41df1ac99f5fcb685fb25c7e61613742d5108b010975a9a6521", size = 444246, upload-time = "2026-03-10T21:30:46.571Z" }, + { url = "https://files.pythonhosted.org/packages/b2/04/7b5705d5b3c0fab088f434f9c83edac1573830ca49ccf29fb83bf7178eec/tornado-6.5.5-cp39-abi3-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:e74c92e8e65086b338fd56333fb9a68b9f6f2fe7ad532645a290a464bcf46be5", size = 447229, upload-time = "2026-03-10T21:30:48.273Z" }, + { url = "https://files.pythonhosted.org/packages/34/01/74e034a30ef59afb4097ef8659515e96a39d910b712a89af76f5e4e1f93c/tornado-6.5.5-cp39-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:435319e9e340276428bbdb4e7fa732c2d399386d1de5686cb331ec8eee754f07", size = 448192, upload-time = "2026-03-10T21:30:51.22Z" }, + { url = "https://files.pythonhosted.org/packages/be/00/fe9e02c5a96429fce1a1d15a517f5d8444f9c412e0bb9eadfbe3b0fc55bf/tornado-6.5.5-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:3f54aa540bdbfee7b9eb268ead60e7d199de5021facd276819c193c0fb28ea4e", size = 448039, upload-time = "2026-03-10T21:30:53.52Z" }, + { url = "https://files.pythonhosted.org/packages/82/9e/656ee4cec0398b1d18d0f1eb6372c41c6b889722641d84948351ae19556d/tornado-6.5.5-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:36abed1754faeb80fbd6e64db2758091e1320f6bba74a4cf8c09cd18ccce8aca", size = 447445, upload-time = "2026-03-10T21:30:55.541Z" }, + { url = "https://files.pythonhosted.org/packages/5a/76/4921c00511f88af86a33de770d64141170f1cfd9c00311aea689949e274e/tornado-6.5.5-cp39-abi3-win32.whl", hash = "sha256:dd3eafaaeec1c7f2f8fdcd5f964e8907ad788fe8a5a32c4426fbbdda621223b7", size = 448582, upload-time = "2026-03-10T21:30:57.142Z" }, + { url = "https://files.pythonhosted.org/packages/2c/23/f6c6112a04d28eed765e374435fb1a9198f73e1ec4b4024184f21faeb1ad/tornado-6.5.5-cp39-abi3-win_amd64.whl", hash = "sha256:6443a794ba961a9f619b1ae926a2e900ac20c34483eea67be4ed8f1e58d3ef7b", size = 448990, upload-time = "2026-03-10T21:30:58.857Z" }, + { url = "https://files.pythonhosted.org/packages/b7/c8/876602cbc96469911f0939f703453c1157b0c826ecb05bdd32e023397d4e/tornado-6.5.5-cp39-abi3-win_arm64.whl", hash = "sha256:2c9a876e094109333f888539ddb2de4361743e5d21eece20688e3e351e4990a6", size = 448016, upload-time = "2026-03-10T21:31:00.43Z" }, ] [[package]] diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py index dfac2dc1d9..2015e0663f 100644 --- a/tests/integration/__init__.py +++ b/tests/integration/__init__.py @@ -442,7 +442,7 @@ def use_cluster(cluster_name, nodes, ipformat=None, start=True, workloads=None, else: log.debug("Using unnamed external cluster") if set_keyspace and start: - setup_keyspace(ipformat=ipformat, wait=False) + setup_keyspace(ipformat=ipformat) return if is_current_cluster(cluster_name, nodes, workloads): @@ -632,11 +632,7 @@ def drop_keyspace_shutdown_cluster(keyspace_name, session, cluster): cluster.shutdown() -def setup_keyspace(ipformat=None, wait=True, protocol_version=None, port=9042): - # wait for nodes to startup - if wait: - time.sleep(10) - +def setup_keyspace(ipformat=None, protocol_version=None, port=9042): if protocol_version: _protocol_version = protocol_version else: @@ -715,6 +711,27 @@ def xfail_scylla_version_lt(reason, oss_scylla_version, ent_scylla_version, *arg return pytest.mark.xfail(current_version < Version(oss_scylla_version), reason=reason, *args, **kwargs) + +def skip_scylla_version_lt(reason, scylla_version): + """ + Skip tests on scylla versions older than the specified thresholds. + :param reason: message explaining why the test is skipped + :param scylla_version: str, version from which test supposed to work + """ + if not (reason.startswith("scylladb/scylladb#") or reason.startswith("scylladb/scylla-enterprise#")): + raise ValueError('reason should start with scylladb/scylladb# or scylladb/scylla-enterprise# to reference issue in scylla repo') + + if not isinstance(scylla_version, str): + raise ValueError('scylla_version should be a str') + + if SCYLLA_VERSION is None: + return pytest.mark.skipif(False, reason="It is just a NoOP Decor, should not skip anything") + + current_version = Version(get_scylla_version(SCYLLA_VERSION)) + + return pytest.mark.skipif(current_version < Version(scylla_version), reason=reason) + + class UpDownWaiter(object): def __init__(self, host): diff --git a/tests/integration/simulacron/test_connection.py b/tests/integration/simulacron/test_connection.py index 818d0b46b9..ceceea814f 100644 --- a/tests/integration/simulacron/test_connection.py +++ b/tests/integration/simulacron/test_connection.py @@ -23,7 +23,7 @@ from cassandra.policies import HostStateListener, RoundRobinPolicy, WhiteListRoundRobinPolicy from tests import connection_class, thread_pool_executor_class -from tests.util import late +from tests.util import late, wait_until_not_raised from tests.integration import requiressimulacron, libevtest from tests.integration.util import assert_quiescent_pool_state # important to import the patch PROTOCOL_VERSION from the simulacron module @@ -356,13 +356,15 @@ def test_retry_after_defunct(self): for _ in range(10): session.execute(query_to_prime) - # Might take some time to close the previous connections and reconnect - time.sleep(10) - assert_quiescent_pool_state(cluster) + # Wait for previous connections to close and pool to stabilize + wait_until_not_raised( + lambda: assert_quiescent_pool_state(cluster), + delay=1, max_attempts=30) clear_queries() - time.sleep(10) - assert_quiescent_pool_state(cluster) + wait_until_not_raised( + lambda: assert_quiescent_pool_state(cluster), + delay=1, max_attempts=30) def test_idle_connection_is_not_closed(self): """ diff --git a/tests/integration/simulacron/utils.py b/tests/integration/simulacron/utils.py index b6136e247a..2322319234 100644 --- a/tests/integration/simulacron/utils.py +++ b/tests/integration/simulacron/utils.py @@ -89,8 +89,13 @@ def start_simulacron(): SERVER_SIMULACRON.start() - # TODO improve this sleep, maybe check the logs like ccm - time.sleep(5) + # Poll the admin endpoint until simulacron is ready + def _check_simulacron_ready(): + opener = build_opener(HTTPHandler) + request = Request("http://127.0.0.1:8187/cluster") + opener.open(request, timeout=2) + + wait_until_not_raised(_check_simulacron_ready, delay=0.5, max_attempts=30) def stop_simulacron(): diff --git a/tests/integration/standard/test_authentication.py b/tests/integration/standard/test_authentication.py index eb8019bf65..502fdf8993 100644 --- a/tests/integration/standard/test_authentication.py +++ b/tests/integration/standard/test_authentication.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os + from packaging.version import Version import logging import time @@ -38,18 +40,34 @@ def setup_module(): use_singledc(start=False) ccm_cluster = get_cluster() ccm_cluster.stop() - config_options = {'authenticator': 'PasswordAuthenticator', - 'authorizer': 'CassandraAuthorizer'} + config_options = { + 'authenticator': 'PasswordAuthenticator', + 'authorizer': 'CassandraAuthorizer', + 'auth_superuser_name': 'cassandra', + 'auth_superuser_salted_password': '$6$x7IFjiX5VCpvNiFk$2IfjTvSyGL7zerpV.wbY7mJjaRCrJ/68dtT3UpT.sSmNYz1bPjtn3mH.kJKFvaZ2T4SbVeBijjmwGjcb83LlV/' + } ccm_cluster.set_configuration_options(config_options) log.debug("Starting ccm test cluster with %s", config_options) start_cluster_wait_for_up(ccm_cluster) # PYTHON-1328 # - # Give the cluster enough time to startup (and perform necessary initialization) - # before executing the test. + # Wait for PasswordAuthenticator to finish initializing (creating the + # default superuser). Poll by attempting to authenticate rather than + # using a fixed sleep. if CASSANDRA_VERSION > Version('4.0-a'): - time.sleep(10) + from tests.util import wait_until_not_raised + + def _check_auth_ready(): + cluster = TestCluster(protocol_version=PROTOCOL_VERSION, + auth_provider=PlainTextAuthProvider('cassandra', 'cassandra')) + try: + session = cluster.connect() + session.execute("SELECT * FROM system.local WHERE key='local'") + finally: + cluster.shutdown() + + wait_until_not_raised(_check_auth_ready, delay=1, max_attempts=30) def teardown_module(): remove_cluster() # this test messes with config diff --git a/tests/integration/standard/test_client_routes.py b/tests/integration/standard/test_client_routes.py new file mode 100644 index 0000000000..a8a3c30f2c --- /dev/null +++ b/tests/integration/standard/test_client_routes.py @@ -0,0 +1,1314 @@ +# Copyright 2026 ScyllaDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Comprehensive integration tests for Client Routes (Private Link) support. + +Includes: +- TCP proxy and NLB emulator for simulating private link infrastructure +- Tests verifying all connections go exclusively through the proxy +- Tests for dynamic route updates and topology changes +- Tests for query_routes filtering +""" + +import logging +import os +import select +import shutil +import socket +import ssl +import subprocess +import tempfile +import threading +import time +import unittest +import uuid + +import json as _json +import urllib.request + +from cassandra.cluster import Cluster +from cassandra.client_routes import ClientRoutesConfig, ClientRouteProxy +from cassandra.connection import ClientRoutesEndPoint +from cassandra.policies import RoundRobinPolicy +from tests.integration import ( + TestCluster, + get_cluster, + get_node, + use_cluster, + wait_for_node_socket, + skip_scylla_version_lt, +) +from tests.util import wait_until_not_raised + +log = logging.getLogger(__name__) + +class TcpProxy: + """ + A simple TCP proxy that forwards connections from a local listen port + to a target (host, port). Tracks active connections so tests can + verify that traffic flows through the proxy. + """ + + BUF_SIZE = 65536 + + def __init__(self, listen_host, listen_port, target_host, target_port): + self.listen_host = listen_host + self.listen_port = listen_port + self.target_host = target_host + self.target_port = target_port + + self._server_sock = None + self._running = False + self._thread = None + self._lock = threading.Lock() + self._connections = set() + self.total_connections = 0 + + def start(self): + self._server_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._server_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self._server_sock.bind((self.listen_host, self.listen_port)) + self.listen_port = self._server_sock.getsockname()[1] + self._server_sock.listen(128) + self._server_sock.setblocking(False) + self._running = True + self._thread = threading.Thread(target=self._run, daemon=True, + name="proxy-%s:%d" % (self.listen_host, self.listen_port)) + self._thread.start() + log.info("TcpProxy started %s:%d -> %s:%d", + self.listen_host, self.listen_port, + self.target_host, self.target_port) + + def stop(self): + self._running = False + if self._server_sock: + try: + self._server_sock.close() + except Exception: + pass + with self._lock: + for csock, tsock in list(self._connections): + self._close_pair(csock, tsock) + self._connections.clear() + if self._thread: + self._thread.join(timeout=5) + log.info("TcpProxy stopped %s:%d", self.listen_host, self.listen_port) + + @property + def active_connections(self): + with self._lock: + return len(self._connections) + + def retarget(self, new_host, new_port): + """Change the backend target for new connections (existing ones keep the old target).""" + self.target_host = new_host + self.target_port = new_port + log.info("TcpProxy %s:%d retargeted to %s:%d", + self.listen_host, self.listen_port, new_host, new_port) + + def drop_connections(self): + """Forcibly close all active connections.""" + with self._lock: + for csock, tsock in list(self._connections): + self._close_pair(csock, tsock) + self._connections.clear() + log.info("TcpProxy %s:%d dropped all connections", self.listen_host, self.listen_port) + + def _run(self): + while self._running: + try: + readable, _, _ = select.select([self._server_sock], [], [], 0.2) + except (ValueError, OSError): + break + for sock in readable: + if sock is self._server_sock: + try: + client_sock, _ = self._server_sock.accept() + except OSError: + continue + self._handle_new_connection(client_sock) + + def _handle_new_connection(self, client_sock, target_host=None, target_port=None): + target_host = target_host or self.target_host + target_port = target_port or self.target_port + try: + target_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + target_sock.connect((target_host, target_port)) + except Exception as e: + log.warning("TcpProxy %s:%d failed to connect to target %s:%d: %s", + self.listen_host, self.listen_port, + target_host, target_port, e) + client_sock.close() + return + + with self._lock: + self._connections.add((client_sock, target_sock)) + self.total_connections += 1 + + t = threading.Thread(target=self._forward_loop, + args=(client_sock, target_sock), + daemon=True) + t.start() + + def _forward_loop(self, client_sock, target_sock): + try: + while self._running: + readable, _, _ = select.select([client_sock, target_sock], [], [], 0.5) + for sock in readable: + data = sock.recv(self.BUF_SIZE) + if not data: + return + if sock is client_sock: + target_sock.sendall(data) + else: + client_sock.sendall(data) + except (OSError, ConnectionResetError, BrokenPipeError): + pass + finally: + with self._lock: + self._connections.discard((client_sock, target_sock)) + self._close_pair(client_sock, target_sock) + + @staticmethod + def _close_pair(csock, tsock): + for s in (csock, tsock): + try: + s.close() + except Exception: + pass + + +class NLBEmulator: + """ + Emulates a Network Load Balancer for a CCM cluster. + + Provides: + - One *discovery port* (round-robin across all live nodes, used as the + driver's ``contact_points``). + - One *per-node port* for each node (dedicated proxy to that node's + native transport port). + + All proxies listen on ``LISTEN_HOST`` (127.254.254.101), an address + outside the CCM node range, simulating a real NLB endpoint. + + Port layout (all ports are OS-assigned by default): + LISTEN_HOST:discovery_port -> round-robin to all live nodes + LISTEN_HOST: -> node1 (127.0.0.1:9042) + LISTEN_HOST: -> node2 (127.0.0.2:9042) + ... + + Automatically creates/removes per-node proxies when nodes are + added/removed so CCM cluster operations are reflected seamlessly. + """ + + LISTEN_HOST = "127.254.254.101" + + def __init__(self, discovery_port=0, + per_node_base=0, + native_port=9042, + node_addresses=None): + self.discovery_port = discovery_port + self.per_node_base = per_node_base + self.native_port = native_port + self._deferred_node_addresses = node_addresses + + self._node_proxies = {} + self._discovery_proxy = None + self._rr_index = 0 + self._lock = threading.Lock() + self._running = False + + def start(self, node_addresses): + """ + Start the NLB with an initial set of node addresses. + + :param node_addresses: dict of node_id -> ip_address, e.g. + {1: "127.0.0.1", 2: "127.0.0.2"} + """ + self._running = True + try: + for node_id, addr in node_addresses.items(): + self._add_node_proxy(node_id, addr) + + first_addr = list(node_addresses.values())[0] + self._discovery_proxy = TcpProxy( + self.LISTEN_HOST, self.discovery_port, + first_addr, self.native_port, + ) + self._discovery_proxy.start() + self.discovery_port = self._discovery_proxy.listen_port + except Exception: + self.stop() + raise + original_handler = self._discovery_proxy._handle_new_connection + + def rr_handler(client_sock): + addrs = self._live_addresses() + if not addrs: + client_sock.close() + return + idx = self._rr_index % len(addrs) + self._rr_index += 1 + addr = addrs[idx] + original_handler(client_sock, target_host=addr, target_port=self.native_port) + + self._discovery_proxy._handle_new_connection = rr_handler + + log.info("NLB started: discovery=%s:%d, %d node proxies", + self.LISTEN_HOST, self.discovery_port, len(self._node_proxies)) + return self + + def __enter__(self): + if not self._running and self._deferred_node_addresses is not None: + self.start(self._deferred_node_addresses) + return self + + def __exit__(self, *args): + self.stop() + + def stop(self): + self._running = False + if self._discovery_proxy: + self._discovery_proxy.stop() + for proxy in self._node_proxies.values(): + proxy.stop() + self._node_proxies.clear() + log.info("NLB stopped") + + def add_node(self, node_id, addr): + self._add_node_proxy(node_id, addr) + + def remove_node(self, node_id): + with self._lock: + proxy = self._node_proxies.pop(node_id, None) + if proxy: + proxy.stop() + log.info("NLB removed node %d", node_id) + + def node_port(self, node_id): + proxy = self._node_proxies.get(node_id) + if proxy: + return proxy.listen_port + return self.per_node_base + node_id + + def get_node_proxy(self, node_id): + return self._node_proxies.get(node_id) + + def total_proxy_connections(self): + return sum(p.total_connections for p in self._node_proxies.values()) + + def active_proxy_connections(self): + return sum(p.active_connections for p in self._node_proxies.values()) + + def drop_all_connections(self): + for proxy in self._node_proxies.values(): + proxy.drop_connections() + if self._discovery_proxy: + self._discovery_proxy.drop_connections() + + def _add_node_proxy(self, node_id, addr): + port = 0 + proxy = TcpProxy(self.LISTEN_HOST, port, addr, self.native_port) + proxy.start() + with self._lock: + self._node_proxies[node_id] = proxy + log.info("NLB added node %d: %s:%d -> %s:%d", + node_id, self.LISTEN_HOST, port, addr, self.native_port) + + def _live_addresses(self): + """IPs of nodes with active proxies.""" + return [p.target_host for p in self._node_proxies.values()] + +def post_client_routes(contact_point, routes): + """ + Post client routes to Scylla's REST API. + + :param contact_point: IP/hostname of a Scylla node (e.g. "127.0.0.1") + :param routes: List of route dicts with keys: connection_id, host_id, address, port + and optionally tls_port + """ + payload = [] + for route in routes: + entry = { + "connection_id": str(route["connection_id"]), + "host_id": str(route["host_id"]), + "address": route["address"], + "port": route["port"], + } + if route.get("tls_port") is not None: + entry["tls_port"] = route["tls_port"] + payload.append(entry) + + url = "http://%s:10000/v2/client-routes" % contact_point + log.info("Posting %d routes to %s", len(payload), url) + data = _json.dumps(payload).encode("utf-8") + req = urllib.request.Request( + url, + data=data, + headers={ + "Content-Type": "application/json", + "Accept": "application/json", + }, + method="POST", + ) + response = urllib.request.urlopen(req) + log.info("Routes posted successfully (status %d)", response.status) + + +def get_host_ids_from_cluster(session): + """ + Build a mapping of rpc_address -> host_id for all nodes in the cluster. + + Uses the driver's metadata rather than querying system.local / system.peers + directly, because those queries can be routed to different coordinators + (system.local returns the coordinator's own info while system.peers omits + the coordinator), leading to a node being missing from the map. + """ + host_id_map = {} + for host in session.cluster.metadata.all_hosts(): + host_id_map[host.address] = host.host_id + return host_id_map + + +def build_routes_for_nlb(connection_id, host_id_map, nlb): + """ + Build routes that direct each host_id through the NLB per-node proxy. + + :param connection_id: Connection ID string + :param host_id_map: dict ip -> uuid host_id (from get_host_ids_from_cluster) + :param nlb: NLBEmulator instance + :return: list of route dicts + """ + routes = [] + for ip, host_id in host_id_map.items(): + node_id = int(ip.split(".")[-1]) + port = nlb.node_port(node_id) + routes.append({ + "connection_id": connection_id, + "host_id": host_id, + "address": NLBEmulator.LISTEN_HOST, + "port": port, + }) + return routes + + +def post_routes_for_nlb(contact_point, connection_id, host_id_map, nlb): + """Build routes for the NLB and POST them via the REST API.""" + routes = build_routes_for_nlb(connection_id, host_id_map, nlb) + post_client_routes(contact_point, routes) + return routes + +def wait_for_routes_visible(session, connection_id, expected_count, timeout=10, poll_interval=0.1): + """ + Poll system.client_routes on **every** node until each one sees at + least *expected_count* rows for *connection_id*. + + ``system.client_routes`` is a node-local table, so routes posted via + the REST API to one node are not guaranteed to be visible on the + others at the same time. This helper ensures they have propagated + everywhere before the test proceeds. + + :param session: an active driver Session (direct, not through NLB) + :param connection_id: the connection_id string to filter on + :param expected_count: how many rows we expect to see per node + :param timeout: maximum seconds to wait + :param poll_interval: seconds between polls + """ + all_hosts = list(session.cluster.metadata.all_hosts()) + deadline = time.time() + timeout + while True: + pending_hosts = [] + for host in all_hosts: + rows = list(session.execute( + "SELECT * FROM system.client_routes WHERE connection_id = %s", + (connection_id,), + host=host, + )) + if len(rows) < expected_count: + pending_hosts.append((host, len(rows))) + if not pending_hosts: + return + if time.time() >= deadline: + details = ", ".join( + "%s: %d" % (h.address, count) for h, count in pending_hosts + ) + raise RuntimeError( + "Timed out waiting for %d routes (connection_id=%s) to appear " + "in system.client_routes on all nodes; pending: %s" + % (expected_count, connection_id, details) + ) + time.sleep(poll_interval) + + +def node_id_from_ip(ip): + """Extract node_id from an IP like '127.0.0.3' -> 3.""" + return int(ip.split(".")[-1]) + + +def assert_routes_via_nlb(test, cluster, nlb, expected_node_ids): + """ + Assert that every host in *expected_node_ids* has its endpoint + resolving through the NLB (correct address and per-node port). + """ + nlb_listen_host = NLBEmulator.LISTEN_HOST + expected_node_ids = set(expected_node_ids) + + seen_node_ids = set() + for host in cluster.metadata.all_hosts(): + ep = host.endpoint + if not isinstance(ep, ClientRoutesEndPoint): + continue + node_id = node_id_from_ip(ep.address) + if node_id not in expected_node_ids: + continue + resolved_addr, resolved_port = ep.resolve() + test.assertEqual( + resolved_addr, nlb_listen_host, + "Node %d endpoint should resolve to NLB address %s, got %s" + % (node_id, nlb_listen_host, resolved_addr), + ) + test.assertEqual( + resolved_port, nlb.node_port(node_id), + "Node %d endpoint should resolve to NLB port %d, got %d" + % (node_id, nlb.node_port(node_id), resolved_port), + ) + seen_node_ids.add(node_id) + test.assertEqual( + seen_node_ids, expected_node_ids, + "Not all expected nodes found in metadata endpoints", + ) + + +def assert_routes_direct(test, cluster, expected_node_ids, direct_port=9042): + """ + Assert that every host in *expected_node_ids* has its endpoint + resolving to the node's own IP on *direct_port*. + """ + expected_node_ids = set(expected_node_ids) + + for host in cluster.metadata.all_hosts(): + ep = host.endpoint + if not isinstance(ep, ClientRoutesEndPoint): + continue + node_id = node_id_from_ip(ep.address) + if node_id not in expected_node_ids: + continue + resolved_addr, resolved_port = ep.resolve() + expected_ip = "127.0.0.%d" % node_id + test.assertEqual( + resolved_addr, expected_ip, + "Node %d endpoint should resolve to direct address %s, got %s" + % (node_id, expected_ip, resolved_addr), + ) + test.assertEqual( + resolved_port, direct_port, + "Node %d endpoint should resolve to direct port %d, got %d" + % (node_id, direct_port, resolved_port), + ) + + +def setup_module(): + os.environ['SCYLLA_EXT_OPTS'] = "--smp 2 --memory 2048M" + use_cluster('test_client_routes', [3], start=True) + +@skip_scylla_version_lt(reason='scylladb/scylladb#26992 - system.client_routes is not yet supported', + scylla_version="2026.1.0") +class TestGetHostPortMapping(unittest.TestCase): + """ + Test _query_all_routes_for_connections and _query_routes_for_change_event + methods with different filtering scenarios. + """ + + @classmethod + def setUpClass(cls): + cls.cluster = TestCluster(client_routes_config=ClientRoutesConfig( + proxies=[ClientRouteProxy("conn_id", "127.0.0.1")])) + cls.session = cls.cluster.connect() + + cls.host_ids = [uuid.uuid4() for _ in range(3)] + cls.connection_ids = [str(uuid.uuid4()) for _ in range(3)] + cls.expected = [] + + for idx, host_id in enumerate(cls.host_ids): + ip = f"127.0.0.{idx + 1}" + for connection_id in cls.connection_ids: + cls.expected.append({ + 'connection_id': connection_id, + 'host_id': host_id, + 'address': ip, + 'port': 9042, + 'tls_port': 9142, + }) + + cls._sort_routes(cls.expected) + post_client_routes(cls.cluster.contact_points[0], cls.expected) + + @classmethod + def tearDownClass(cls): + cls.cluster.shutdown() + + @staticmethod + def _sort_routes(routes): + routes.sort(key=lambda r: (str(r['connection_id']), str(r['host_id']))) + + def _routes_to_dicts(self, routes): + """Convert _Route objects to comparable dicts, adjusting port for ssl_enabled.""" + return [ + { + 'connection_id': route.connection_id, + 'host_id': route.host_id, + 'address': route.address, + 'port': route.port, + } + for route in routes + ] + + def _expected_dicts(self, expected): + """Build expected dicts with tls_port or port based on ssl_enabled.""" + port_key = 'tls_port' if self.cluster._client_routes_handler.ssl_enabled else 'port' + return [ + { + 'connection_id': e['connection_id'], + 'host_id': e['host_id'], + 'address': e['address'], + 'port': e[port_key], + } + for e in expected + ] + + def test_get_all_routes_for_all_connections(self): + """Querying all connection IDs returns every route.""" + cc = self.cluster.control_connection + routes = self.cluster._client_routes_handler._query_all_routes_for_connections( + cc._connection, cc._timeout, self.connection_ids, + ) + got = self._routes_to_dicts(routes) + self._sort_routes(got) + expected = self._expected_dicts(self.expected) + self._sort_routes(expected) + self.assertEqual(got, expected) + + def test_get_routes_for_single_connection(self): + """Querying a single connection ID returns only its routes.""" + cc = self.cluster.control_connection + routes = self.cluster._client_routes_handler._query_all_routes_for_connections( + cc._connection, cc._timeout, [self.connection_ids[0]], + ) + got = self._routes_to_dicts(routes) + self._sort_routes(got) + filtered = [r for r in self.expected + if r['connection_id'] == self.connection_ids[0]] + expected = self._expected_dicts(filtered) + self._sort_routes(expected) + self.assertEqual(got, expected) + + def test_get_routes_for_change_event_all_pairs(self): + """Querying all (connection_id, host_id) pairs returns every route.""" + cc = self.cluster.control_connection + pairs = [(r['connection_id'], r['host_id']) for r in self.expected] + routes = self.cluster._client_routes_handler._query_routes_for_change_event( + cc._connection, cc._timeout, pairs, + ) + got = self._routes_to_dicts(routes) + self._sort_routes(got) + expected = self._expected_dicts(self.expected) + self._sort_routes(expected) + self.assertEqual(got, expected) + + def test_get_routes_for_change_event_single_pair(self): + """Querying a single (connection_id, host_id) pair returns one route.""" + cc = self.cluster.control_connection + target_conn_id = self.connection_ids[0] + target_host_id = self.host_ids[0] + routes = self.cluster._client_routes_handler._query_routes_for_change_event( + cc._connection, cc._timeout, [(target_conn_id, target_host_id)], + ) + got = self._routes_to_dicts(routes) + self._sort_routes(got) + filtered = [r for r in self.expected + if r['connection_id'] == target_conn_id + and r['host_id'] == target_host_id] + expected = self._expected_dicts(filtered) + self._sort_routes(expected) + self.assertEqual(got, expected) + +@skip_scylla_version_lt(reason='scylladb/scylladb#26992 - system.client_routes is not yet supported', + scylla_version="2026.1.0") +class TestPrivateLinkConnectivity(unittest.TestCase): + """ + Verifies the driver connects to all cluster nodes exclusively through + the NLB proxy, never directly. + + Setup: + 1. Start a 3-node CCM cluster (done by setup_module). + 2. Start an NLB emulator with per-node proxies. + 3. Use a direct session to read host_ids, then POST client routes + pointing each host_id at the NLB proxy port. + 4. Create a client-routes-enabled session using the NLB discovery + port as the contact point. + 5. Verify all driver connections go through proxy ports. + """ + + @classmethod + def setUpClass(cls): + cls.direct_cluster = TestCluster() + cls.direct_session = cls.direct_cluster.connect() + cls.host_id_map = get_host_ids_from_cluster(cls.direct_session) + log.info("Host ID map: %s", cls.host_id_map) + + cls.node_addrs = {} + for ip in cls.host_id_map: + node_id = int(ip.split(".")[-1]) + cls.node_addrs[node_id] = ip + + cls.nlb = NLBEmulator() + cls.nlb.start(cls.node_addrs) + + cls.connection_id = str(uuid.uuid4()) + post_routes_for_nlb("127.0.0.1", cls.connection_id, cls.host_id_map, cls.nlb) + wait_for_routes_visible(cls.direct_session, cls.connection_id, len(cls.host_id_map)) + + @classmethod + def tearDownClass(cls): + cls.direct_cluster.shutdown() + cls.nlb.stop() + + def _make_client_routes_cluster(self, **extra_kwargs): + """Create a Cluster configured with client-routes pointing at the NLB.""" + return Cluster( + contact_points=[NLBEmulator.LISTEN_HOST], + port=self.nlb.discovery_port, + client_routes_config=ClientRoutesConfig( + proxies=[ClientRouteProxy(self.connection_id, NLBEmulator.LISTEN_HOST)], + ), + load_balancing_policy=RoundRobinPolicy(), + **extra_kwargs, + ) + + def test_all_connections_through_proxy(self): + """Every pool connection must go through the NLB proxy, not directly.""" + with self._make_client_routes_cluster() as cluster: + session = cluster.connect(wait_for_all_pools=True) + + for _ in range(50): + session.execute("SELECT key FROM system.local") + + pool_state = session.get_pool_state() + self.assertEqual(len(pool_state), len(self.node_addrs), + "Driver should have pools for all nodes") + + for host, state in pool_state.items(): + node_id = node_id_from_ip(host.address) + proxy = self.nlb.get_node_proxy(node_id) + self.assertIsNotNone(proxy, f"No proxy for node {node_id}") + open_count = state['open_count'] + self.assertGreaterEqual( + proxy.total_connections, open_count, + f"Node {node_id} proxy saw {proxy.total_connections} " + f"connections but pool has {open_count} open — " + f"some connections bypassed the proxy") + + assert_routes_via_nlb(self, cluster, self.nlb, + self.node_addrs.keys()) + + def test_queries_succeed_through_proxy(self): + """Queries should work normally through the proxy.""" + with self._make_client_routes_cluster() as cluster: + session = cluster.connect() + session.execute( + "CREATE KEYSPACE IF NOT EXISTS test_cr_ks " + "WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 3}" + ) + session.execute( + "CREATE TABLE IF NOT EXISTS test_cr_ks.t (k int PRIMARY KEY, v text)" + ) + session.execute("INSERT INTO test_cr_ks.t (k, v) VALUES (1, 'hello')") + row = session.execute("SELECT v FROM test_cr_ks.t WHERE k = 1").one() + self.assertEqual(row.v, "hello") + + assert_routes_via_nlb(self, cluster, self.nlb, + self.node_addrs.keys()) + + def test_connection_recovery_after_proxy_drop(self): + """ + After the proxy drops all connections, the driver should reconnect + (still through the proxy). + """ + with self._make_client_routes_cluster() as cluster: + session = cluster.connect(wait_for_all_pools=True) + session.execute("SELECT key FROM system.local") + + assert_routes_via_nlb(self, cluster, self.nlb, + self.node_addrs.keys()) + + self.nlb.drop_all_connections() + + def query_ok(): + session.execute("SELECT key FROM system.local") + + wait_until_not_raised(query_ok, 1, 30) + + assert_routes_via_nlb(self, cluster, self.nlb, + self.node_addrs.keys()) + +@skip_scylla_version_lt(reason='scylladb/scylladb#26992 - system.client_routes is not yet supported', + scylla_version="2026.1.0") +class TestDynamicRouteUpdates(unittest.TestCase): + """ + Verify that when routes are updated (e.g. port changes), the driver + picks up the new routes and reconnects through the new proxy ports + after existing connections are dropped. + """ + + @classmethod + def setUpClass(cls): + cls.direct_cluster = TestCluster() + cls.direct_session = cls.direct_cluster.connect() + cls.host_id_map = get_host_ids_from_cluster(cls.direct_session) + + cls.node_addrs = {} + for ip in cls.host_id_map: + node_id = int(ip.split(".")[-1]) + cls.node_addrs[node_id] = ip + + cls.connection_id = str(uuid.uuid4()) + + @classmethod + def tearDownClass(cls): + cls.direct_cluster.shutdown() + + def test_route_update_causes_reconnect_to_new_port(self): + """ + 1. Start NLB v1, post routes -> driver connects through v1 ports. + 2. Start NLB v2 on different ports, post new routes. + 3. Drop v1 connections. + 4. Driver should reconnect through v2 ports. + """ + with NLBEmulator( + node_addresses=self.node_addrs, + ) as nlb_v1, NLBEmulator( + node_addresses=self.node_addrs, + ) as nlb_v2: + post_routes_for_nlb("127.0.0.1", self.connection_id, + self.host_id_map, nlb_v1) + wait_for_routes_visible(self.direct_session, self.connection_id, len(self.host_id_map)) + + with Cluster( + contact_points=[NLBEmulator.LISTEN_HOST], + port=nlb_v1.discovery_port, + client_routes_config=ClientRoutesConfig( + proxies=[ClientRouteProxy(self.connection_id, NLBEmulator.LISTEN_HOST)], + ), + load_balancing_policy=RoundRobinPolicy(), + ) as cluster: + session = cluster.connect(wait_for_all_pools=True) + session.execute("SELECT key FROM system.local") + + for node_id in self.node_addrs: + self.assertGreater( + nlb_v1.get_node_proxy(node_id).total_connections, 0) + assert_routes_via_nlb(self, cluster, nlb_v1, + self.node_addrs.keys()) + + post_routes_for_nlb("127.0.0.1", self.connection_id, + self.host_id_map, nlb_v2) + time.sleep(2) # let CLIENT_ROUTES_CHANGE propagate + + # Stop v1 per-node proxies entirely so v1 ports become + # unreachable, forcing the driver to reconnect through v2. + # (Merely dropping connections is insufficient because v1 + # proxies would still accept new connections before the + # route update propagates.) + for node_id in list(self.node_addrs.keys()): + nlb_v1.remove_node(node_id) + + def all_nodes_via_v2(): + session.execute("SELECT key FROM system.local") + for nid in self.node_addrs: + assert nlb_v2.get_node_proxy(nid).total_connections > 0, \ + "NLB v2 node %d proxy has no connections yet" % nid + + wait_until_not_raised(all_nodes_via_v2, 1, 30) + + assert_routes_via_nlb(self, cluster, nlb_v2, + self.node_addrs.keys()) + + +def _generate_ssl_certs(cert_dir, node_ips): + """ + Generate test SSL certificates with SANs covering the given node IPs. + + File names follow CCM's ``ScyllaCluster.enable_ssl()`` convention so the + resulting directory can be passed directly to ``enable_ssl(cert_dir, ...)``. + + Creates: + - ca.key / ca.crt: self-signed CA + - ccm_node.key / ccm_node.pem: server cert signed by CA with SANs for all node_ips + + :param cert_dir: directory to write files into (must exist) + :param node_ips: list of IP strings to include as SANs (e.g. ["127.0.0.1", "127.0.0.2"]) + """ + if shutil.which("openssl") is None: + raise unittest.SkipTest("openssl not found on PATH; skipping SSL cert generation") + + san_cnf = os.path.join(cert_dir, "san.cnf") + san_value = ",".join("IP:%s" % ip for ip in node_ips) + with open(san_cnf, "w") as f: + f.write("subjectAltName=%s\n" % san_value) + + def _run(cmd): + result = subprocess.run(cmd, cwd=cert_dir, capture_output=True, text=True) + if result.returncode != 0: + raise RuntimeError("Command failed: %s\n%s" % (" ".join(cmd), result.stderr)) + + _run(["openssl", "req", "-x509", "-newkey", "rsa:2048", + "-keyout", "ca.key", "-out", "ca.crt", + "-days", "1", "-nodes", "-subj", "/CN=Test CA"]) + + _run(["openssl", "req", "-newkey", "rsa:2048", + "-keyout", "ccm_node.key", "-out", "ccm_node.csr", + "-nodes", "-subj", "/CN=Test Server"]) + + _run(["openssl", "x509", "-req", + "-in", "ccm_node.csr", "-CA", "ca.crt", "-CAkey", "ca.key", + "-CAcreateserial", "-out", "ccm_node.pem", + "-days", "1", "-extfile", "san.cnf"]) + + log.info("Generated SSL certs in %s with SANs: %s", cert_dir, san_value) + + +@skip_scylla_version_lt(reason='scylladb/scylladb#26992 - system.client_routes is not yet supported', + scylla_version="2026.1.0") +class TestMixedDirectAndNlbConnections(unittest.TestCase): + """ + Verify the cluster works when some nodes are accessed through the NLB + proxy and others are accessed directly (no route posted, falls back + to the default endpoint). + """ + + @classmethod + def setUpClass(cls): + cls.direct_cluster = TestCluster() + cls.direct_session = cls.direct_cluster.connect() + cls.host_id_map = get_host_ids_from_cluster(cls.direct_session) + + cls.node_addrs = {} + for ip in cls.host_id_map: + node_id = int(ip.split(".")[-1]) + cls.node_addrs[node_id] = ip + + cls.connection_id = str(uuid.uuid4()) + + @classmethod + def tearDownClass(cls): + cls.direct_cluster.shutdown() + + def test_mixed_direct_and_nlb_connections(self): + """ + Post routes for only a subset of nodes (through NLB proxy). + Remaining nodes have no route and fall back to direct connections. + Queries should work through both paths. + """ + proxied_node_id = min(self.node_addrs.keys()) + proxied_ip = self.node_addrs[proxied_node_id] + + with NLBEmulator( + node_addresses={proxied_node_id: proxied_ip}, + ) as nlb: + proxied_host_id = self.host_id_map[proxied_ip] + routes = [{ + "connection_id": self.connection_id, + "host_id": proxied_host_id, + "address": NLBEmulator.LISTEN_HOST, + "port": nlb.node_port(proxied_node_id), + }] + post_client_routes("127.0.0.1", routes) + time.sleep(1) + + with Cluster( + contact_points=["127.0.0.1"], + client_routes_config=ClientRoutesConfig( + proxies=[ClientRouteProxy(self.connection_id, NLBEmulator.LISTEN_HOST)], + ), + load_balancing_policy=RoundRobinPolicy(), + ) as cluster: + session = cluster.connect(wait_for_all_pools=True) + + for _ in range(50): + session.execute("SELECT key FROM system.local") + + assert_routes_via_nlb(self, cluster, nlb, + [proxied_node_id]) + + direct_node_ids = set(self.node_addrs.keys()) - {proxied_node_id} + assert_routes_direct(self, cluster, direct_node_ids) + + proxy = nlb.get_node_proxy(proxied_node_id) + self.assertGreater(proxy.total_connections, 0, + "Proxied node should have connections through NLB") + + +@skip_scylla_version_lt(reason='scylladb/scylladb#26992 - system.client_routes is not yet supported', + scylla_version="2026.1.0") +class TestSslThroughNlb(unittest.TestCase): + """ + Verify SSL with check_hostname=False works through the NLB proxy. + + When using client routes, connections go through NLB proxies whose + addresses won't match server certificates, so hostname verification + must be disabled. Certificate chain validation (verify_mode=CERT_REQUIRED) + is still active — only hostname matching is skipped. + + The driver raises ValueError at Cluster init time if check_hostname=True + is used with client_routes_config. + """ + + @classmethod + def setUpClass(cls): + cls.direct_cluster = TestCluster() + cls.direct_session = cls.direct_cluster.connect() + cls.host_id_map = get_host_ids_from_cluster(cls.direct_session) + cls.direct_cluster.shutdown() + + cls.node_addrs = {} + for ip in cls.host_id_map: + node_id = int(ip.split(".")[-1]) + cls.node_addrs[node_id] = ip + + cls.connection_id = str(uuid.uuid4()) + + cls.cert_dir = tempfile.mkdtemp(prefix="client-routes-ssl-") + cert_ips = list(cls.node_addrs.values()) + _generate_ssl_certs(cls.cert_dir, cert_ips) + + cls.ccm_cluster = get_cluster() + cls.ccm_cluster.stop() + cls.ccm_cluster.set_configuration_options({ + 'client_encryption_options': { + 'enabled': True, + 'certificate': os.path.join(cls.cert_dir, "ccm_node.pem"), + 'keyfile': os.path.join(cls.cert_dir, "ccm_node.key"), + } + }) + cls.ccm_cluster.start(wait_for_binary_proto=True) + + @classmethod + def tearDownClass(cls): + cls.ccm_cluster.stop() + cls.ccm_cluster.set_configuration_options({ + 'client_encryption_options': { + 'enabled': False, + } + }) + cls.ccm_cluster.start(wait_for_binary_proto=True) + + shutil.rmtree(cls.cert_dir, ignore_errors=True) + + def test_ssl_without_hostname_verification_through_nlb(self): + """ + Connect through NLB with SSL but check_hostname=False. + + When using client routes, connections go through NLB proxies + whose addresses won't match server certificates, so hostname + verification must be disabled. Certificate chain validation + (verify_mode=CERT_REQUIRED) is still active. + """ + with NLBEmulator( + node_addresses=self.node_addrs, + ) as nlb: + routes = build_routes_for_nlb( + self.connection_id, self.host_id_map, nlb, + ) + for route in routes: + route["tls_port"] = route["port"] + post_client_routes("127.0.0.1", routes) + + ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ssl_ctx.check_hostname = False + ssl_ctx.load_verify_locations(os.path.join(self.cert_dir, 'ca.crt')) + + self.assertFalse(ssl_ctx.check_hostname, + "check_hostname must be False for this test") + self.assertEqual(ssl_ctx.verify_mode, ssl.CERT_REQUIRED, + "verify_mode must be CERT_REQUIRED") + + def routes_visible(): + with TestCluster( + contact_points=["127.0.0.1"], + ssl_context=ssl_ctx, + ) as c: + session = c.connect() + rs = session.execute( + "SELECT * FROM system.client_routes " + "WHERE connection_id = %s ALLOW FILTERING", + (self.connection_id,) + ) + return len(list(rs)) >= len(self.host_id_map) + + wait_until_not_raised( + lambda: self.assertTrue(routes_visible()), + 0.5, 10, + ) + + with Cluster( + contact_points=[NLBEmulator.LISTEN_HOST], + port=nlb.discovery_port, + ssl_context=ssl_ctx, + client_routes_config=ClientRoutesConfig( + proxies=[ClientRouteProxy(self.connection_id, NLBEmulator.LISTEN_HOST)], + ), + load_balancing_policy=RoundRobinPolicy(), + ) as cluster: + session = cluster.connect(wait_for_all_pools=True) + + for _ in range(20): + row = session.execute( + "SELECT release_version FROM system.local" + ).one() + self.assertIsNotNone(row) + + assert_routes_via_nlb(self, cluster, nlb, + self.node_addrs.keys()) + + def test_ssl_with_hostname_verification_raises_error(self): + """ + Verify that Cluster raises ValueError when client_routes_config + is used with SSL hostname verification enabled. + """ + ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ssl_ctx.load_verify_locations(os.path.join(self.cert_dir, 'ca.crt')) + self.assertTrue(ssl_ctx.check_hostname) + + with self.assertRaises(ValueError) as cm: + Cluster( + contact_points=[NLBEmulator.LISTEN_HOST], + ssl_context=ssl_ctx, + client_routes_config=ClientRoutesConfig( + proxies=[ClientRouteProxy("test-id", NLBEmulator.LISTEN_HOST)], + ), + ) + self.assertIn("check_hostname", str(cm.exception)) + +@skip_scylla_version_lt(reason='scylladb/scylladb#26992 - system.client_routes is not yet supported', + scylla_version="2026.1.0") +class TestFullNodeReplacementThroughNlb(unittest.TestCase): + """ + End-to-end test: creates a session through an NLB proxy with client routes, + scales the cluster up, then decommissions original nodes, verifying the + session survives the full node replacement. + + This test is destructive — it modifies the CCM cluster topology by + bootstrapping new nodes and decommissioning original ones. It uses + its own CCM cluster so it cannot interfere with other tests. + """ + + @classmethod + def setUpClass(cls): + os.environ['SCYLLA_EXT_OPTS'] = "--smp 2 --memory 2048M" + use_cluster('test_client_routes_replacement', [3], start=True) + + cls.direct_cluster = TestCluster() + cls.direct_session = cls.direct_cluster.connect() + cls.host_id_map = get_host_ids_from_cluster(cls.direct_session) + + cls.node_addrs = {} + for ip in cls.host_id_map: + node_id = int(ip.split(".")[-1]) + cls.node_addrs[node_id] = ip + + cls.connection_id = str(uuid.uuid4()) + + @classmethod + def tearDownClass(cls): + cls.direct_cluster.shutdown() + + def test_should_survive_full_node_replacement_through_nlb(self): + """ + 1. Start with 3 nodes behind the NLB + 2. Bootstrap 2 new nodes, add to NLB, update routes + 3. Decommission the original 3 nodes one-by-one, updating NLB/routes + 4. Verify the session survives with only new nodes + """ + original_node_ids = sorted(self.node_addrs.keys()) + with NLBEmulator( + node_addresses=self.node_addrs, + ) as nlb: + # ---- Stage 1: Set up NLB for initial nodes ---- + log.info("Stage 1: Setting up NLB for %d initial nodes", len(original_node_ids)) + + post_routes_for_nlb("127.0.0.1", self.connection_id, self.host_id_map, nlb) + wait_for_routes_visible(self.direct_session, self.connection_id, len(self.host_id_map)) + + # ---- Stage 2: Create session through NLB ---- + log.info("Stage 2: Creating session through NLB") + with Cluster( + contact_points=[NLBEmulator.LISTEN_HOST], + port=nlb.discovery_port, + client_routes_config=ClientRoutesConfig( + proxies=[ClientRouteProxy(self.connection_id, NLBEmulator.LISTEN_HOST)], + ), + load_balancing_policy=RoundRobinPolicy(), + ) as cluster: + session = cluster.connect(wait_for_all_pools=True) + self._assert_query_works(session) + + handler = cluster._client_routes_handler + self.assertIsNotNone(handler) + + assert_routes_via_nlb(self, cluster, nlb, + original_node_ids) + log.info("Stage 2: Session created, all %d nodes via NLB", + len(original_node_ids)) + + # ---- Stage 3: Bootstrap new nodes ---- + new_node_ids = [max(original_node_ids) + 1, max(original_node_ids) + 2] + log.info("Stage 3: Adding nodes %s", new_node_ids) + ccm_cluster = get_cluster() + + for node_id in new_node_ids: + self._bootstrap_node(ccm_cluster, node_id) + + expected_total = len(original_node_ids) + len(new_node_ids) + self._wait_for_condition( + lambda: len(cluster.metadata.all_hosts()) >= expected_total, + timeout_seconds=60, + description="%d nodes in metadata" % expected_total, + ) + + for node_id in new_node_ids: + nlb.add_node(node_id, "127.0.0.%d" % node_id) + + all_host_ids = get_host_ids_from_cluster(session) + log.info("All host IDs after expansion: %s", all_host_ids) + post_routes_for_nlb("127.0.0.1", self.connection_id, all_host_ids, nlb) + + handler.initialize( + cluster.control_connection._connection, + cluster.control_connection._timeout) + + self._wait_for_condition( + lambda: sum(1 for h in cluster.metadata.all_hosts() if h.is_up) >= expected_total, + timeout_seconds=60, + description="all %d nodes up" % expected_total, + ) + + self._assert_query_works(session) + + all_node_ids = set(original_node_ids) | set(new_node_ids) + assert_routes_via_nlb(self, cluster, nlb, all_node_ids) + log.info("Stage 3: All %d nodes via NLB after expansion", + len(all_node_ids)) + + # ---- Stage 4: Decommission original nodes ---- + log.info("Stage 4: Decommissioning original nodes %s", original_node_ids) + + remaining_node_ids = set(all_node_ids) + remaining_host_ids = dict(all_host_ids) + for node_id in original_node_ids: + log.info("Decommissioning node %d", node_id) + get_node(node_id).decommission() + nlb.remove_node(node_id) + remaining_node_ids.discard(node_id) + + ip = "127.0.0.%d" % node_id + remaining_host_ids.pop(ip, None) + + surviving_ips = list(remaining_host_ids.keys()) + if surviving_ips: + post_routes_for_nlb( + surviving_ips[0], self.connection_id, + remaining_host_ids, nlb, + ) + + expected_remaining = expected_total - (original_node_ids.index(node_id) + 1) + self._wait_for_condition( + lambda er=expected_remaining: ( + len(cluster.metadata.all_hosts()) <= er + and self._query_succeeds(session) + ), + timeout_seconds=60, + description="node %d decommissioned" % node_id, + ) + + # Reload routes after the control connection has + # re-established itself (the decommission may have + # killed the old control connection). + handler.initialize( + cluster.control_connection._connection, + cluster.control_connection._timeout) + + assert_routes_via_nlb(self, cluster, nlb, + remaining_node_ids) + log.info("Node %d decommissioned, %d nodes still via NLB", + node_id, len(remaining_node_ids)) + + # ---- Stage 5: Verify with only new nodes ---- + log.info("Stage 5: Verifying session works with only new nodes %s", new_node_ids) + self._assert_query_works(session) + + hosts = cluster.metadata.all_hosts() + self.assertEqual( + len(hosts), len(new_node_ids), + "Expected %d hosts, got %d" % (len(new_node_ids), len(hosts)) + ) + + for _ in range(10): + self._assert_query_works(session) + + assert_routes_via_nlb(self, cluster, nlb, new_node_ids) + log.info("PASS: Full node replacement, all %d new nodes via NLB", + len(new_node_ids)) + + def _assert_query_works(self, session): + rs = session.execute("SELECT release_version FROM system.local WHERE key='local'") + row = rs.one() + self.assertIsNotNone(row, "Query via NLB should return a result") + + def _query_succeeds(self, session): + try: + self._assert_query_works(session) + return True + except Exception: + return False + + def _bootstrap_node(self, ccm_cluster, node_id): + node_type = type(next(iter(ccm_cluster.nodes.values()))) + ip = "127.0.0.%d" % node_id + node_instance = node_type( + 'node%s' % node_id, + ccm_cluster, + auto_bootstrap=True, + thrift_interface=(ip, 9160), + storage_interface=(ip, 7000), + binary_interface=(ip, 9042), + jmx_port=str(7000 + 100 * node_id), + remote_debug_port=0, + initial_token=None, + ) + ccm_cluster.add(node_instance, is_seed=False) + node_instance.start(wait_for_binary_proto=True, wait_other_notice=True) + wait_for_node_socket(node_instance, 120) + log.info("Node %d bootstrapped successfully", node_id) + + @staticmethod + def _wait_for_condition(predicate, timeout_seconds, poll_interval=2, description="condition"): + deadline = time.time() + timeout_seconds + while time.time() < deadline: + if predicate(): + return True + time.sleep(poll_interval) + raise AssertionError( + "Timed out waiting for %s after %d seconds" % (description, timeout_seconds) + ) diff --git a/tests/integration/standard/test_cluster.py b/tests/integration/standard/test_cluster.py index bf62f5df48..aab4131739 100644 --- a/tests/integration/standard/test_cluster.py +++ b/tests/integration/standard/test_cluster.py @@ -1121,8 +1121,7 @@ def test_stale_connections_after_shutdown(self): """ for _ in range(10): with TestCluster(protocol_version=3) as cluster: - cluster.connect().execute("SELECT * FROM system_schema.keyspaces") - time.sleep(1) + cluster.connect(wait_for_all_pools=True).execute("SELECT * FROM system_schema.keyspaces") with TestCluster(protocol_version=3) as cluster: session = cluster.connect() diff --git a/tests/integration/standard/test_connection.py b/tests/integration/standard/test_connection.py index 630e5e6ba0..df0f568c2c 100644 --- a/tests/integration/standard/test_connection.py +++ b/tests/integration/standard/test_connection.py @@ -32,6 +32,7 @@ from tests import is_monkey_patched from tests.integration import use_singledc, get_node, CASSANDRA_IP, local, \ requiresmallclockgranularity, greaterthancass20, TestCluster +from tests.util import wait_until try: import cassandra.io.asyncorereactor @@ -140,9 +141,10 @@ def test_heart_beat_timeout(self): # Wait for connections associated with this host go away self.wait_for_no_connections(host, self.cluster) - # Wait to seconds for the driver to be notified - time.sleep(2) - assert test_listener.host_down + # Wait for the driver to detect the host is down + wait_until( + lambda: test_listener.host_down, + delay=0.5, max_attempts=20) # Resume paused node finally: node.resume() diff --git a/tests/integration/standard/test_metrics.py b/tests/integration/standard/test_metrics.py index 7b502d91c3..7ebdded141 100644 --- a/tests/integration/standard/test_metrics.py +++ b/tests/integration/standard/test_metrics.py @@ -25,6 +25,7 @@ from cassandra.cluster import NoHostAvailable, ExecutionProfile, EXEC_PROFILE_DEFAULT from tests.integration import get_cluster, get_node, use_singledc, execute_until_pass, TestCluster +from tests.util import wait_until, wait_until_not_raised from cassandra import metrics from tests.integration import BasicSharedKeyspaceUnitTestCaseRF3WM, BasicExistingKeyspaceUnitTestCase, local @@ -75,8 +76,10 @@ def test_connection_error(self): self.session.execute(query) finally: get_cluster().start(wait_for_binary_proto=True, wait_other_notice=True) - # Give some time for the cluster to come back up, for the next test - time.sleep(5) + # Wait for the cluster to come back up for the next test + wait_until_not_raised( + lambda: self.session.execute("SELECT key FROM system.local WHERE key='local'"), + delay=0.5, max_attempts=30) assert self.cluster.metrics.stats.connection_errors > 0 @@ -156,7 +159,10 @@ def test_unavailable(self): # Sometimes this commands continues with the other nodes having not noticed # 1 is down, and a Timeout error is returned instead of an Unavailable get_node(1).stop(wait=True, wait_other_notice=True) - time.sleep(5) + wait_until( + lambda: not self.cluster.metadata.get_host('127.0.0.1') or + not self.cluster.metadata.get_host('127.0.0.1').is_up, + delay=0.5, max_attempts=30) try: # Test write query = SimpleStatement("INSERT INTO test (k, v) VALUES (2, 2)", consistency_level=ConsistencyLevel.ALL) @@ -171,8 +177,10 @@ def test_unavailable(self): assert self.cluster.metrics.stats.unavailables == 2 finally: get_node(1).start(wait_other_notice=True, wait_for_binary_proto=True) - # Give some time for the cluster to come back up, for the next test - time.sleep(5) + # Wait for the cluster to come back up for the next test + wait_until_not_raised( + lambda: self.session.execute("SELECT key FROM system.local WHERE key='local'"), + delay=0.5, max_attempts=30) self.cluster.shutdown() diff --git a/tests/integration/standard/test_query.py b/tests/integration/standard/test_query.py index 9cebc22b05..f9d3dc26bc 100644 --- a/tests/integration/standard/test_query.py +++ b/tests/integration/standard/test_query.py @@ -29,7 +29,7 @@ USE_CASS_EXTERNAL, greaterthanorequalcass40, TestCluster, xfail_scylla from tests import notwindows from tests.integration import greaterthanorequalcass30, get_node -from tests.util import assertListEqual +from tests.util import assertListEqual, wait_until import time import random @@ -1571,9 +1571,10 @@ def test_reprepare_after_host_is_down(self): get_node(1).start(wait_for_binary_proto=True, wait_other_notice=True) - # We wait for cluster._prepare_all_queries to be called - time.sleep(5) - assert 1 == mock_handler.get_message_count('debug', 'Preparing all known prepared statements') + # Wait for cluster._prepare_all_queries to be called + wait_until( + lambda: mock_handler.get_message_count('debug', 'Preparing all known prepared statements') >= 1, + delay=0.5, max_attempts=20) results = self.session.execute(prepared_statement, (1,), execution_profile="only_first") assert results.one() == (1, ) diff --git a/tests/integration/standard/test_shard_aware.py b/tests/integration/standard/test_shard_aware.py index 48d1aa3609..2d764d681e 100644 --- a/tests/integration/standard/test_shard_aware.py +++ b/tests/integration/standard/test_shard_aware.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -import time import random from subprocess import run import logging @@ -27,6 +26,7 @@ from cassandra import OperationTimedOut, ConsistencyLevel from tests.integration import use_cluster, get_node, PROTOCOL_VERSION +from tests.util import wait_until_not_raised LOGGER = logging.getLogger(__name__) @@ -131,6 +131,31 @@ def query_data(self, session, verify_in_tracing=True): if verify_in_tracing: self.verify_same_shard_in_tracing(results, "shard 0") + def _assert_blocked_node_disconnected(self, node_ip_address, node_port): + control_connection = self.cluster.control_connection + active_control_connection = control_connection._connection if control_connection else None + if active_control_connection and \ + active_control_connection.endpoint.address == node_ip_address and \ + active_control_connection.endpoint.port == node_port: + assert active_control_connection.is_closed or active_control_connection.is_defunct + + pools = getattr(self.session, '_pools', None) or {} + for host, pool in pools.items(): + if host.endpoint.address != node_ip_address or host.endpoint.port != node_port: + continue + + open_connections = [ + connection for connection in pool._connections.values() + if not (connection.is_closed or connection.is_defunct) + ] + assert not open_connections + + pending_connections = [ + connection for connection in pool._pending_connections + if not (connection.is_closed or connection.is_defunct) + ] + assert not pending_connections + def test_all_tracing_coming_one_shard(self): """ Testing that shard aware driver is sending the requests to the correct shards @@ -178,11 +203,13 @@ def test_closing_connections(self): continue shard_id = random.choice(list(pool._connections.keys())) pool._connections.get(shard_id).close() - time.sleep(5) - self.query_data(self.session, verify_in_tracing=False) + wait_until_not_raised( + lambda: self.query_data(self.session, verify_in_tracing=False), + delay=0.5, max_attempts=30) - time.sleep(10) - self.query_data(self.session) + wait_until_not_raised( + lambda: self.query_data(self.session), + delay=0.5, max_attempts=60) @pytest.mark.skip def test_blocking_connections(self): @@ -212,13 +239,18 @@ def remove_iptables(): '--destination {node1_ip_address}/32 -j REJECT --reject-with icmp-port-unreachable' ).format(node1_ip_address=node1_ip_address, node1_port=node1_port).split(' ') ) - time.sleep(5) + + wait_until_not_raised( + lambda: self._assert_blocked_node_disconnected(node1_ip_address, node1_port), + delay=0.1, + max_attempts=50) try: self.query_data(self.session, verify_in_tracing=False) except OperationTimedOut: pass remove_iptables() - time.sleep(5) - self.query_data(self.session, verify_in_tracing=False) + wait_until_not_raised( + lambda: self.query_data(self.session, verify_in_tracing=False), + delay=0.5, max_attempts=30) self.query_data(self.session) diff --git a/tests/integration/standard/test_tablets.py b/tests/integration/standard/test_tablets.py index d9439e5c2c..d969140339 100644 --- a/tests/integration/standard/test_tablets.py +++ b/tests/integration/standard/test_tablets.py @@ -1,11 +1,10 @@ -import time - import pytest from cassandra.cluster import Cluster, EXEC_PROFILE_DEFAULT, ExecutionProfile from cassandra.policies import ConstantReconnectionPolicy, RoundRobinPolicy, TokenAwarePolicy from tests.integration import PROTOCOL_VERSION, use_cluster, get_cluster +from tests.util import wait_until from tests.unit.test_host_connection_pool import LOGGER @@ -28,7 +27,7 @@ def teardown_class(cls): cls.cluster.shutdown() def verify_hosts_in_tracing(self, results, expected): - traces = results.get_query_trace() + traces = results.get_query_trace(max_wait_sec=10) events = traces.events host_set = set() for event in events: @@ -54,7 +53,7 @@ def get_tablet_record(self, query): return metadata._tablets.get_tablet_for_key(query.keyspace, query.table, metadata.token_map.token_class.from_key(query.routing_key)) def verify_same_shard_in_tracing(self, results): - traces = results.get_query_trace() + traces = results.get_query_trace(max_wait_sec=10) events = traces.events shard_set = set() for event in events: @@ -212,7 +211,10 @@ def test_tablets_invalidation_drop_ks(self): def drop_ks(_): # Drop and recreate ks and table to trigger tablets invalidation self.create_ks_and_cf(self.cluster.connect()) - time.sleep(3) + # Wait for tablet metadata to be refreshed + wait_until( + lambda: 'test1' in self.cluster.metadata.keyspaces, + delay=0.5, max_attempts=20) self.run_tablets_invalidation_test(drop_ks) @@ -233,7 +235,12 @@ def decommission_non_cc_node(rec): break else: assert False, "failed to find node to decommission" - time.sleep(10) + # Wait for decommission to complete and metadata to update + wait_until( + lambda: len([h for h in self.cluster.metadata.all_hosts() if h.is_up]) < 3, + delay=1, max_attempts=60) + # Tablet metadata invalidation may take additional time to propagate; + # run_tablets_invalidation_test will poll for the expected result. self.run_tablets_invalidation_test(decommission_non_cc_node) @@ -257,5 +264,7 @@ def run_tablets_invalidation_test(self, invalidate): invalidate(rec) - # Check if tablets information was purged - assert self.get_tablet_record(bound) is None, "tablet was not deleted, invalidation did not work" + # Wait for tablets information to be purged (invalidation is async) + wait_until( + lambda: self.get_tablet_record(bound) is None, + delay=0.5, max_attempts=20) diff --git a/tests/integration/upgrade/__init__.py b/tests/integration/upgrade/__init__.py index a1c751bcbd..fab6fed34a 100644 --- a/tests/integration/upgrade/__init__.py +++ b/tests/integration/upgrade/__init__.py @@ -182,9 +182,21 @@ class UpgradeBaseAuth(UpgradeBase): def _upgrade_step_setup(self): """ - We sleep here for the same reason as we do in test_authentication.py: - there seems to be some race, with some versions of C* taking longer to - get the auth (and default user) setup. Sleep here to give it a chance + Wait for PasswordAuthenticator to finish initializing (creating the + default superuser). Poll by attempting to authenticate rather than + using a fixed sleep. """ super(UpgradeBaseAuth, self)._upgrade_step_setup() - time.sleep(10) + + from cassandra.auth import PlainTextAuthProvider + from tests.util import wait_until_not_raised + + def _check_auth_ready(): + c = Cluster(auth_provider=PlainTextAuthProvider('cassandra', 'cassandra')) + try: + s = c.connect() + s.execute("SELECT * FROM system.local WHERE key='local'") + finally: + c.shutdown() + + wait_until_not_raised(_check_auth_ready, delay=1, max_attempts=30) diff --git a/tests/integration/upgrade/test_upgrade.py b/tests/integration/upgrade/test_upgrade.py index fec9a38604..45827723b3 100644 --- a/tests/integration/upgrade/test_upgrade.py +++ b/tests/integration/upgrade/test_upgrade.py @@ -19,11 +19,22 @@ from cassandra.cluster import ConsistencyLevel, Cluster, DriverException, ExecutionProfile from cassandra.policies import ConstantSpeculativeExecutionPolicy from tests.integration.upgrade import UpgradeBase, UpgradeBaseAuth, UpgradePath, upgrade_paths +from tests.util import wait_until import unittest import pytest +def _wait_for_control_connection(cluster_driver, timeout=60): + """Wait for the driver's control connection to be established.""" + wait_until( + lambda: cluster_driver.control_connection._connection is not None + and not cluster_driver.control_connection._connection.is_closed, + delay=1, + max_attempts=timeout, + ) + + # Previous Cassandra upgrade two_to_three_path = upgrade_paths([ UpgradePath("2.2.9-3.11", {"version": "2.2.9"}, {"version": "3.11.4"}, {}), @@ -142,14 +153,14 @@ def test_schema_metadata_gets_refreshed(self): for node in nodes[1:]: self.upgrade_node(node) # Wait for the control connection to reconnect - time.sleep(20) + _wait_for_control_connection(self.cluster_driver) with pytest.raises(DriverException): self.cluster_driver.refresh_schema_metadata(max_schema_agreement_wait=10) self.upgrade_node(nodes[0]) # Wait for the control connection to reconnect - time.sleep(20) + _wait_for_control_connection(self.cluster_driver) self.cluster_driver.refresh_schema_metadata(max_schema_agreement_wait=40) assert original_meta != self.cluster_driver.metadata.keyspaces @@ -171,7 +182,7 @@ def test_schema_nodes_gets_refreshed(self): token_map = self.cluster_driver.metadata.token_map self.upgrade_node(node) # Wait for the control connection to reconnect - time.sleep(20) + _wait_for_control_connection(self.cluster_driver) self.cluster_driver.refresh_nodes(force_token_rebuild=True) self._assert_same_token_map(token_map, self.cluster_driver.metadata.token_map) diff --git a/tests/unit/io/test_twistedreactor.py b/tests/unit/io/test_twistedreactor.py index 54abe884ae..8ba9ca5b1d 100644 --- a/tests/unit/io/test_twistedreactor.py +++ b/tests/unit/io/test_twistedreactor.py @@ -99,14 +99,23 @@ def setUp(self): self.reactor_cft_patcher = patch( 'twisted.internet.reactor.callFromThread') self.reactor_run_patcher = patch('twisted.internet.reactor.run') + # Patch reactor.running to False so maybe_start() always enters + # the branch that spawns the reactor thread. Without this, leaked + # reactor state from prior tests can cause reactor.running to be + # True, making maybe_start() a no-op and the reactor.run mock + # never called — leading to a flaky test_connection_initialization. + self.reactor_running_patcher = patch( + 'twisted.internet.reactor.running', new=False) self.mock_reactor_cft = self.reactor_cft_patcher.start() self.mock_reactor_run = self.reactor_run_patcher.start() + self.reactor_running_patcher.start() self.obj_ut = twistedreactor.TwistedConnection(DefaultEndPoint('1.2.3.4'), cql_version='3.0.1') def tearDown(self): self.reactor_cft_patcher.stop() self.reactor_run_patcher.stop() + self.reactor_running_patcher.stop() def test_connection_initialization(self): """ diff --git a/tests/unit/test_cache_apply_parameters.py b/tests/unit/test_cache_apply_parameters.py new file mode 100644 index 0000000000..58f41f6acf --- /dev/null +++ b/tests/unit/test_cache_apply_parameters.py @@ -0,0 +1,90 @@ +""" +Unit tests for apply_parameters caching in _CassandraType. +""" +import unittest +from cassandra.cqltypes import ( + MapType, SetType, ListType, TupleType, + Int32Type, UTF8Type, FloatType, DoubleType, BooleanType, + _CassandraType, +) + + +class TestApplyParametersCache(unittest.TestCase): + + def setUp(self): + _CassandraType._apply_parameters_cache.clear() + + def test_cache_returns_same_object(self): + """Repeated apply_parameters calls return the exact same class object.""" + result1 = MapType.apply_parameters([UTF8Type, Int32Type]) + result2 = MapType.apply_parameters([UTF8Type, Int32Type]) + self.assertIs(result1, result2) + + def test_cache_different_subtypes_different_results(self): + """Different subtype combinations produce different cached classes.""" + r1 = MapType.apply_parameters([UTF8Type, Int32Type]) + r2 = MapType.apply_parameters([Int32Type, UTF8Type]) + self.assertIsNot(r1, r2) + + def test_cache_different_base_types(self): + """Different base types with same subtypes produce different classes.""" + r1 = SetType.apply_parameters([Int32Type]) + r2 = ListType.apply_parameters([Int32Type]) + self.assertIsNot(r1, r2) + + def test_cached_type_has_correct_subtypes(self): + """Cached types preserve their subtype information.""" + result = MapType.apply_parameters([UTF8Type, FloatType]) + self.assertEqual(result.subtypes, (UTF8Type, FloatType)) + # Call again, verify cache hit still has correct subtypes + result2 = MapType.apply_parameters([UTF8Type, FloatType]) + self.assertEqual(result2.subtypes, (UTF8Type, FloatType)) + + def test_cached_type_has_correct_cassname(self): + """Cached types preserve their cassname.""" + result = SetType.apply_parameters([DoubleType]) + self.assertEqual(result.cassname, SetType.cassname) + + def test_cached_type_with_names(self): + """Caching works correctly with named parameters (UDT-style).""" + r1 = TupleType.apply_parameters([Int32Type, UTF8Type], names=['id', 'name']) + r2 = TupleType.apply_parameters([Int32Type, UTF8Type], names=['id', 'name']) + self.assertIs(r1, r2) + + def test_different_names_different_cache_entries(self): + """Different names produce different cached classes.""" + r1 = TupleType.apply_parameters([Int32Type, UTF8Type], names=['id', 'name']) + r2 = TupleType.apply_parameters([Int32Type, UTF8Type], names=['key', 'value']) + self.assertIsNot(r1, r2) + + def test_names_none_vs_no_names(self): + """Passing names=None and not passing names use the same cache entry.""" + r1 = MapType.apply_parameters([UTF8Type, Int32Type], names=None) + r2 = MapType.apply_parameters([UTF8Type, Int32Type]) + self.assertIs(r1, r2) + + def test_tuple_subtypes_accepted(self): + """Both list and tuple subtypes produce the same cached result.""" + r1 = MapType.apply_parameters([UTF8Type, Int32Type]) + r2 = MapType.apply_parameters((UTF8Type, Int32Type)) + self.assertIs(r1, r2) + + def test_cache_populated(self): + """The cache dict is populated after apply_parameters calls.""" + _CassandraType._apply_parameters_cache.clear() + MapType.apply_parameters([UTF8Type, Int32Type]) + self.assertGreater(len(_CassandraType._apply_parameters_cache), 0) + + def test_cache_clear_forces_new_creation(self): + """Clearing the cache forces new type creation.""" + r1 = MapType.apply_parameters([UTF8Type, Int32Type]) + _CassandraType._apply_parameters_cache.clear() + r2 = MapType.apply_parameters([UTF8Type, Int32Type]) + # After clearing, we get a new class (different object identity) + self.assertIsNot(r1, r2) + # But they should be functionally equivalent + self.assertEqual(r1.subtypes, r2.subtypes) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/unit/test_client_routes.py b/tests/unit/test_client_routes.py new file mode 100644 index 0000000000..0aa82fc76a --- /dev/null +++ b/tests/unit/test_client_routes.py @@ -0,0 +1,482 @@ +# Copyright 2026 ScyllaDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import socket +import ssl +import unittest +import uuid +from unittest.mock import Mock, patch + +from cassandra.client_routes import ( + ClientRouteProxy, + ClientRoutesChangeType, + ClientRoutesConfig, + _RouteStore, + _Route, + _ClientRoutesHandler +) +from cassandra.connection import ClientRoutesEndPoint, ClientRoutesEndPointFactory +from cassandra.cluster import Cluster + + +class TestClientRouteProxy(unittest.TestCase): + + def test_endpoint_none_connection_id(self): + with self.assertRaises(ValueError): + ClientRouteProxy(None) + + +class TestClientRoutesConfig(unittest.TestCase): + + def test_config_with_proxies(self): + ep1 = ClientRouteProxy(str(uuid.uuid4()), "10.0.0.1") + ep2 = ClientRouteProxy(str(uuid.uuid4()), "10.0.0.2") + config = ClientRoutesConfig([ep1, ep2]) + self.assertEqual(len(config.proxies), 2) + + def test_config_empty_proxies(self): + with self.assertRaises(ValueError): + ClientRoutesConfig([]) + + def test_config_invalid_proxy_type(self): + with self.assertRaises(TypeError): + ClientRoutesConfig(["not-a-proxy"]) + + + +class TestRouteStore(unittest.TestCase): + + def test_get_by_host_id(self): + routes = _RouteStore() + host_id = uuid.uuid4() + route = _Route( + connection_id=str(uuid.uuid4()), + host_id=host_id, + address="example.com", + port=9042, + ) + + routes.update([route]) + + retrieved = routes.get_by_host_id(host_id) + self.assertEqual(retrieved.host_id, host_id) + self.assertEqual(retrieved.address, "example.com") + + def test_merge_routes(self): + routes = _RouteStore() + host_id1 = uuid.uuid4() + host_id2 = uuid.uuid4() + + route1 = _Route( + connection_id=str(uuid.uuid4()), host_id=host_id1, + address="host1.com", port=9042, + ) + + route2 = _Route( + connection_id=str(uuid.uuid4()), host_id=host_id2, + address="host2.com", port=9042, + ) + + routes.update([route1]) + routes.merge([route2], affected_host_ids={host_id2}) + + self.assertIsNotNone(routes.get_by_host_id(host_id1)) + self.assertIsNotNone(routes.get_by_host_id(host_id2)) + + def test_merge_deletes_affected_host_with_no_new_route(self): + """When an affected host_id has no corresponding new route, it should be removed.""" + store = _RouteStore() + host_id1 = uuid.uuid4() + host_id2 = uuid.uuid4() + conn_id = str(uuid.uuid4()) + + store.update([ + _Route(connection_id=conn_id, host_id=host_id1, address="a.com", port=9042), + _Route(connection_id=conn_id, host_id=host_id2, address="b.com", port=9042), + ]) + self.assertIsNotNone(store.get_by_host_id(host_id1)) + self.assertIsNotNone(store.get_by_host_id(host_id2)) + + # Merge with host_id2 affected but no new route for it → deletion + store.merge([], affected_host_ids={host_id2}) + + self.assertIsNotNone(store.get_by_host_id(host_id1)) + self.assertIsNone(store.get_by_host_id(host_id2)) + + def test_select_preferred_routes_keeps_existing_connection_id(self): + """When multiple connection_ids provide routes for the same host_id, + the one already in use should be preferred.""" + store = _RouteStore() + host_id = uuid.uuid4() + conn_a = "conn-a" + conn_b = "conn-b" + + # Populate store with conn_a for host_id + store.update([_Route(connection_id=conn_a, host_id=host_id, address="a.com", port=9042)]) + self.assertEqual(store.get_by_host_id(host_id).connection_id, conn_a) + + # Update with both conn_a and conn_b for the same host_id + store.update([ + _Route(connection_id=conn_b, host_id=host_id, address="b.com", port=9042), + _Route(connection_id=conn_a, host_id=host_id, address="a-new.com", port=9042), + ]) + # conn_a should be preferred since it was already in use + result = store.get_by_host_id(host_id) + self.assertEqual(result.connection_id, conn_a) + self.assertEqual(result.address, "a-new.com") + + def test_select_preferred_routes_falls_back_when_existing_gone(self): + """When the existing connection_id is no longer among candidates, + the first candidate should be selected.""" + store = _RouteStore() + host_id = uuid.uuid4() + + store.update([_Route(connection_id="old-conn", host_id=host_id, address="old.com", port=9042)]) + + # Update only has new connection_ids + store.update([ + _Route(connection_id="new-a", host_id=host_id, address="a.com", port=9042), + _Route(connection_id="new-b", host_id=host_id, address="b.com", port=9042), + ]) + result = store.get_by_host_id(host_id) + self.assertEqual(result.connection_id, "new-a") + + +class TestClientRoutesHandler(unittest.TestCase): + + def setUp(self): + self.conn_id = uuid.uuid4() + self.proxy = ClientRouteProxy(str(self.conn_id), "10.0.0.1") + self.config = ClientRoutesConfig([self.proxy]) + + def test_handler_initialization(self): + handler = _ClientRoutesHandler(self.config, ssl_enabled=False) + self.assertIsNotNone(handler) + self.assertEqual(handler.ssl_enabled, False) + + @patch.object(_ClientRoutesHandler, '_query_all_routes_for_connections') + def test_initialize(self, mock_query): + host_id = uuid.uuid4() + mock_query.return_value = [ + _Route( + connection_id=self.conn_id, + host_id=host_id, + address="node1.example.com", + port=9042, + ) + ] + + handler = _ClientRoutesHandler(self.config) + mock_conn = Mock() + + handler.initialize(mock_conn, timeout=5.0) + + mock_query.assert_called_once() + route = handler._routes.get_by_host_id(host_id) + self.assertIsNotNone(route) + self.assertEqual(route.address, "node1.example.com") + + @patch.object(_ClientRoutesHandler, '_query_routes_for_change_event') + def test_handle_change_filters_by_configured_connection_ids(self, mock_query): + """Events with unrelated connection_ids should be ignored.""" + handler = _ClientRoutesHandler(self.config) + mock_conn = Mock() + host_id = str(uuid.uuid4()) + + # Event with a connection_id NOT in our config → should return early + handler.handle_client_routes_change( + mock_conn, 5.0, + ClientRoutesChangeType.UPDATE_NODES, + connection_ids=["unrelated-conn-id"], + host_ids=[host_id], + ) + mock_query.assert_not_called() + + @patch.object(_ClientRoutesHandler, '_query_routes_for_change_event') + def test_handle_change_merges_when_host_ids_present(self, mock_query): + """When host_ids are provided, routes should be merged (not full replace).""" + handler = _ClientRoutesHandler(self.config) + mock_conn = Mock() + + existing_host = uuid.uuid4() + new_host = uuid.uuid4() + conn_id = str(self.conn_id) + + # Pre-populate a route + handler._routes.update([ + _Route(connection_id=conn_id, host_id=existing_host, address="old.com", port=9042), + ]) + + mock_query.return_value = [ + _Route(connection_id=conn_id, host_id=new_host, address="new.com", port=9042), + ] + + handler.handle_client_routes_change( + mock_conn, 5.0, + ClientRoutesChangeType.UPDATE_NODES, + connection_ids=[conn_id], + host_ids=[str(new_host)], + ) + + # Existing route should still be there (merge, not replace) + self.assertIsNotNone(handler._routes.get_by_host_id(existing_host)) + self.assertIsNotNone(handler._routes.get_by_host_id(new_host)) + + @patch.object(_ClientRoutesHandler, '_query_all_routes_for_connections') + def test_handle_change_updates_when_no_host_ids(self, mock_query): + """When no host_ids are provided, routes should be fully replaced.""" + handler = _ClientRoutesHandler(self.config) + mock_conn = Mock() + conn_id = str(self.conn_id) + + old_host = uuid.uuid4() + handler._routes.update([ + _Route(connection_id=conn_id, host_id=old_host, address="old.com", port=9042), + ]) + + new_host = uuid.uuid4() + mock_query.return_value = [ + _Route(connection_id=conn_id, host_id=new_host, address="new.com", port=9042), + ] + + handler.handle_client_routes_change( + mock_conn, 5.0, + ClientRoutesChangeType.UPDATE_NODES, + connection_ids=None, + host_ids=None, + ) + + # Full replace: old_host gone, new_host present + self.assertIsNone(handler._routes.get_by_host_id(old_host)) + self.assertIsNotNone(handler._routes.get_by_host_id(new_host)) + + @patch.object(_ClientRoutesHandler, '_query_routes_for_change_event') + def test_handle_change_propagates_query_failure(self, mock_query): + """If _query_routes raises, handle_client_routes_change should propagate.""" + handler = _ClientRoutesHandler(self.config) + mock_conn = Mock() + mock_query.side_effect = Exception("network error") + + conn_id = self.proxy.connection_id + host_id = str(uuid.uuid4()) + with self.assertRaises(Exception) as cm: + handler.handle_client_routes_change( + mock_conn, 5.0, + ClientRoutesChangeType.UPDATE_NODES, + connection_ids=[conn_id], + host_ids=[host_id], + ) + self.assertIn("network error", str(cm.exception)) + + @patch.object(_ClientRoutesHandler, '_query_all_routes_for_connections') + def test_initialize_propagates_exception_on_failure(self, mock_query): + """initialize should propagate exceptions to caller.""" + handler = _ClientRoutesHandler(self.config) + mock_conn = Mock() + mock_query.side_effect = Exception("query failed") + + with self.assertRaises(Exception) as ctx: + handler.initialize(mock_conn, 5.0) + self.assertIn("query failed", str(ctx.exception)) + self.assertEqual(mock_query.call_count, 1) + + @patch.object(_ClientRoutesHandler, '_query_all_routes_for_connections') + def test_initialize_keeps_old_routes_on_failure(self, mock_query): + """On failure, existing routes must be preserved (critical for PL clusters).""" + handler = _ClientRoutesHandler(self.config) + mock_conn = Mock() + host_id = uuid.uuid4() + + # Pre-populate a route + handler._routes.update([ + _Route(connection_id=str(self.conn_id), host_id=host_id, address="old.com", port=9042), + ]) + + mock_query.side_effect = Exception("query failed") + with self.assertRaises(Exception): + handler.initialize(mock_conn, 5.0) + + # Old route must still be there + self.assertIsNotNone(handler._routes.get_by_host_id(host_id)) + + @patch.object(_ClientRoutesHandler, '_query_all_routes_for_connections') + def test_initialize_updates_routes_on_success(self, mock_query): + """initialize should update routes on success.""" + handler = _ClientRoutesHandler(self.config) + mock_conn = Mock() + host_id = uuid.uuid4() + + mock_query.return_value = [ + _Route(connection_id=str(self.conn_id), host_id=host_id, address="new.com", port=9042), + ] + + handler.initialize(mock_conn, 5.0) + + self.assertEqual(mock_query.call_count, 1) + route = handler._routes.get_by_host_id(host_id) + self.assertIsNotNone(route) + self.assertEqual(route.address, "new.com") + +class TestClientRoutesEndPoint(unittest.TestCase): + + def setUp(self): + self.conn_id = uuid.uuid4() + self.proxy = ClientRouteProxy(str(self.conn_id), "10.0.0.1") + self.config = ClientRoutesConfig([self.proxy]) + self.handler = _ClientRoutesHandler(self.config, ssl_enabled=False) + + def test_resolve_falls_back_when_no_mapping(self): + """resolve() should return original address/port when no route mapping exists.""" + host_id = uuid.uuid4() + ep = ClientRoutesEndPoint( + host_id=host_id, + handler=self.handler, + original_address="10.0.0.1", + original_port=9042, + ) + self.assertEqual(ep.resolve(), ("10.0.0.1", 9042)) + + @patch('cassandra.client_routes.socket.getaddrinfo', + return_value=[(socket.AF_INET, socket.SOCK_STREAM, 0, '', ("192.168.1.100", 9042))]) + def test_resolve_returns_address_when_route_exists(self, _mock_getaddrinfo): + """resolve() should return the DNS-resolved address and port when a route exists.""" + host_id = uuid.uuid4() + self.handler._routes.update([ + _Route(connection_id=str(self.conn_id), host_id=host_id, + address="nlb.example.com", port=9042), + ]) + ep = ClientRoutesEndPoint( + host_id=host_id, + handler=self.handler, + original_address="10.0.0.1", + original_port=9042, + ) + self.assertEqual(ep.resolve(), ("192.168.1.100", 9042)) + _mock_getaddrinfo.assert_called_once_with( + "nlb.example.com", 9042, socket.AF_UNSPEC, socket.SOCK_STREAM) + + @patch('cassandra.client_routes.socket.getaddrinfo', + side_effect=socket.gaierror("DNS resolution failed")) + def test_resolve_host_dns_failure_raises(self, _mock_getaddrinfo): + """resolve_host should propagate socket.gaierror on DNS failure.""" + host_id = uuid.uuid4() + self.handler._routes.update([ + _Route(connection_id=str(self.conn_id), host_id=host_id, + address="nonexistent.example.com", port=9042), + ]) + with self.assertRaises(socket.gaierror): + self.handler.resolve_host(host_id) + + def test_resolve_host_missing_port_raises(self): + """resolve_host should raise ValueError when route has no port.""" + host_id = uuid.uuid4() + self.handler._routes.update([ + _Route(connection_id=str(self.conn_id), host_id=host_id, + address="host.com", port=0), + ]) + with self.assertRaises(ValueError): + self.handler.resolve_host(host_id) + + +class TestClientRoutesEndPointFactory(unittest.TestCase): + + def setUp(self): + self.conn_id = uuid.uuid4() + proxy = ClientRouteProxy(str(self.conn_id), "10.0.0.1") + self.config = ClientRoutesConfig([proxy]) + self.handler = _ClientRoutesHandler(self.config, ssl_enabled=False) + self.factory = ClientRoutesEndPointFactory(self.handler, default_port=9042) + + def test_create_from_row(self): + """Factory should create a ClientRoutesEndPoint from a peers row.""" + host_id = uuid.uuid4() + row = { + "host_id": host_id, + "rpc_address": "10.0.0.5", + "native_transport_port": 9042, + "peer": "10.0.0.5", + } + ep = self.factory.create(row) + self.assertIsInstance(ep, ClientRoutesEndPoint) + self.assertEqual(ep.host_id, host_id) + self.assertEqual(ep.address, "10.0.0.5") + + def test_create_missing_host_id_raises(self): + """Factory should raise ValueError when row has no host_id.""" + row = {"rpc_address": "10.0.0.5", "native_transport_port": 9042} + with self.assertRaises(ValueError): + self.factory.create(row) + +class TestClientRoutesSSLValidation(unittest.TestCase): + + def test_check_hostname_with_ssl_context_raises(self): + """Cluster should reject check_hostname=True with client_routes_config.""" + ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + self.assertTrue(ssl_ctx.check_hostname) + + config = ClientRoutesConfig( + proxies=[ClientRouteProxy(str(uuid.uuid4()), "10.0.0.1")] + ) + with self.assertRaises(ValueError) as cm: + Cluster( + contact_points=["10.0.0.1"], + ssl_context=ssl_ctx, + client_routes_config=config, + ) + self.assertIn("check_hostname", str(cm.exception)) + + def test_check_hostname_with_ssl_options_raises(self): + """Cluster should reject check_hostname=True in ssl_options with client_routes_config.""" + config = ClientRoutesConfig( + proxies=[ClientRouteProxy(str(uuid.uuid4()), "10.0.0.1")] + ) + with self.assertRaises(ValueError) as cm: + Cluster( + contact_points=["10.0.0.1"], + ssl_options={'check_hostname': True}, + client_routes_config=config, + ) + self.assertIn("check_hostname", str(cm.exception)) + + def test_disabled_check_hostname_with_client_routes_ok(self): + """Cluster should allow check_hostname=False with client_routes_config.""" + ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ssl_ctx.check_hostname = False + + config = ClientRoutesConfig( + proxies=[ClientRouteProxy(str(uuid.uuid4()), "10.0.0.1")] + ) + # Should not raise + cluster = Cluster( + contact_points=["10.0.0.1"], + ssl_context=ssl_ctx, + client_routes_config=config, + ) + cluster.shutdown() + + def test_no_ssl_with_client_routes_ok(self): + """Cluster should allow client_routes_config without SSL.""" + config = ClientRoutesConfig( + proxies=[ClientRouteProxy(str(uuid.uuid4()), "10.0.0.1")] + ) + # Should not raise + cluster = Cluster( + contact_points=["10.0.0.1"], + client_routes_config=config, + ) + cluster.shutdown() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index 49208ac53e..4942fd4d69 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -91,7 +91,10 @@ class ClusterTest(unittest.TestCase): def test_tuple_for_contact_points(self): cluster = Cluster(contact_points=[('localhost', 9045), ('127.0.0.2', 9046), '127.0.0.3'], port=9999) - localhost_addr = set([addr[0] for addr in [t for (_,_,_,_,t) in socket.getaddrinfo("localhost",80)]]) + # Refactored for clarity + addr_info = socket.getaddrinfo("localhost", 80) + sockaddr_tuples = [info[4] for info in addr_info] # info[4] is sockaddr + localhost_addr = set([sockaddr[0] for sockaddr in sockaddr_tuples]) for cp in cluster.endpoints_resolved: if cp.address in localhost_addr: assert cp.port == 9045 @@ -108,7 +111,7 @@ def test_invalid_contact_point_types(self): Cluster(contact_points="not a sequence", protocol_version=4, connect_timeout=1) def test_port_str(self): - """Check port passed as tring is converted and checked properly""" + """Check port passed as string is converted and checked properly""" cluster = Cluster(contact_points=['127.0.0.1'], port='1111') for cp in cluster.endpoints_resolved: if cp.address in ('::1', '127.0.0.1'): @@ -182,7 +185,7 @@ def test_event_delay_timing(self, *_): """ sched = _Scheduler(None) sched.schedule(0, lambda: None) - sched.schedule(0, lambda: None) # pre-473: "TypeError: unorderable types: function() < function()"t + sched.schedule(0, lambda: None) # pre-473: "TypeError: unorderable types: function() < function()" class SessionTest(unittest.TestCase): @@ -292,8 +295,8 @@ def test_default_exec_parameters(self): assert cluster.profile_manager.default.request_timeout == 10.0 assert session.default_consistency_level == ConsistencyLevel.LOCAL_ONE assert cluster.profile_manager.default.consistency_level == ConsistencyLevel.LOCAL_ONE - assert session.default_serial_consistency_level == None - assert cluster.profile_manager.default.serial_consistency_level == None + assert session.default_serial_consistency_level is None + assert cluster.profile_manager.default.serial_consistency_level is None assert session.row_factory == named_tuple_factory assert cluster.profile_manager.default.row_factory == named_tuple_factory diff --git a/tests/unit/test_control_connection.py b/tests/unit/test_control_connection.py index d759e12332..037d4a8888 100644 --- a/tests/unit/test_control_connection.py +++ b/tests/unit/test_control_connection.py @@ -287,6 +287,20 @@ def test_wait_for_schema_agreement_rpc_lookup(self): assert not self.control_connection.wait_for_schema_agreement() assert self.time.clock >= self.cluster.max_schema_agreement_wait + + def test_wait_for_schema_agreement_none_timeout(self): + """ + When control_connection_timeout is None, wait_for_schema_agreement + should not raise a TypeError on the min() call. + """ + cc = ControlConnection(self.cluster, timeout=None, + schema_event_refresh_window=0, + topology_event_refresh_window=0, + status_event_refresh_window=0) + cc._connection = self.connection + cc._time = self.time + assert cc.wait_for_schema_agreement() + def test_refresh_nodes_and_tokens(self): self.control_connection.refresh_node_list_and_token_map() meta = self.cluster.metadata diff --git a/tests/unit/test_types.py b/tests/unit/test_types.py index 7a8c584f75..11aab2748d 100644 --- a/tests/unit/test_types.py +++ b/tests/unit/test_types.py @@ -120,8 +120,9 @@ def test_lookup_casstype(self): assert str(lookup_casstype('unknown')) == str(cassandra.cqltypes.mkUnrecognizedType('unknown')) - with pytest.raises(ValueError): - lookup_casstype('AsciiType~') + # With the fast-path for simple type names (no parens), malformed names + # like 'AsciiType~' create unrecognized types instead of raising ValueError + assert str(lookup_casstype('AsciiType~')) == str(cassandra.cqltypes.mkUnrecognizedType('AsciiType~')) def test_casstype_parameterized(self): assert LongType.cass_parameterized_type_with(()) == 'LongType'