diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 0c4aa63669..3ae00a7ee8 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,3 +1,25 @@ +3.29.9 +====== +March 18, 2026 + +Features +-------- +* Add Private Link support via client routes handler +* Add optional query_params parameter to QueryMessage + +Bug Fixes +--------- +* Fix segmentation fault in libev prepare_callback during shutdown +* Add null checks to io_callback and timer_callback in libev wrapper +* Fix RecursionError in execute_concurrent on synchronous errbacks +* Fix floating-point precision loss for timestamps far from epoch + +Others +------ +* Cache parsed tablet routing type in ResponseFuture +* Remove deprecated setup_requires in favor of PEP 517 build-system.requires +* Update dependency hatchling to v1.29.0 + 3.29.8 ====== February 09, 2026 diff --git a/benchmarks/test_parse_desc_cache_benchmark.py b/benchmarks/test_parse_desc_cache_benchmark.py new file mode 100644 index 0000000000..d45e821be1 --- /dev/null +++ b/benchmarks/test_parse_desc_cache_benchmark.py @@ -0,0 +1,549 @@ +# Copyright ScyllaDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Benchmarks for ParseDesc construction with and without caching. + +The ParseDesc is built on every response in recv_results_rows(). For prepared +statements the column_metadata list is the same object every time, so caching +the ParseDesc (keyed by id(column_metadata)) avoids repeated list +comprehensions, ColDesc construction, and make_deserializers() calls. + +There are two benchmark tiers: + +1. **Integration benchmarks** (test_integration_*): Exercise the actual Cython + _get_or_build_parse_desc function through the real recv_results_rows closure + returned by make_recv_results_rows(). These use a mock ResultMessage with a + binary buffer simulating the prepared statement path (NO_METADATA_FLAG set). + +2. **Isolated benchmarks** (test_parse_desc_*, test_full_pipeline_*): Measure + ParseDesc construction and row parsing using a pure-Python cache replica. + Useful for understanding the breakdown of costs but do not exercise the + actual Cython cache code path. + +Run with: + pytest benchmarks/test_parse_desc_cache_benchmark.py -v +""" + +import io +import struct +import pytest + +# Skip the entire module when pytest-benchmark is not installed. +# The benchmark fixture is provided by the pytest-benchmark plugin which +# is not in the project's dev dependencies. This guard prevents +# "fixture 'benchmark' not found" errors when running bare `pytest` from +# the repo root. +pytest.importorskip("pytest_benchmark") + +from cassandra import cqltypes +from cassandra.policies import ColDesc +from cassandra.parsing import ParseDesc +from cassandra.deserializers import make_deserializers +from cassandra.bytesio import BytesIOReader +from cassandra.obj_parser import ListParser +from cassandra.row_parser import ( + clear_parse_desc_cache, + get_parse_desc_cache_size, + make_recv_results_rows, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _build_column_metadata(ncols, cql_type=cqltypes.UTF8Type): + """Build a column_metadata list like the driver produces.""" + return [("ks", "tbl", "col_%d" % i, cql_type) for i in range(ncols)] + + +def _build_uncached_parse_desc( + column_metadata, column_encryption_policy, protocol_version +): + """Original uncached ParseDesc construction (baseline).""" + column_names = [md[2] for md in column_metadata] + column_types = [md[3] for md in column_metadata] + desc = ParseDesc( + column_names, + column_types, + column_encryption_policy, + [ColDesc(md[0], md[1], md[2]) for md in column_metadata], + make_deserializers(column_types), + protocol_version, + ) + return desc, column_names, column_types + + +def _build_binary_rows(nrows, ncols, col_value=b"hello world"): + """ + Build a binary buffer matching the Cassandra row format: + int32(rowcount) + for each row: + for each col: int32(len) + bytes + """ + parts = [struct.pack(">i", nrows)] + col_cell = struct.pack(">i", len(col_value)) + col_value + row_data = col_cell * ncols + for _ in range(nrows): + parts.append(row_data) + return b"".join(parts) + + +# --------------------------------------------------------------------------- +# Integration helpers — exercise the actual Cython recv_results_rows +# --------------------------------------------------------------------------- + +# NO_METADATA_FLAG as defined in ResultMessage +_NO_METADATA_FLAG = 0x0004 + + +class _MockResultMessage: + """ + Minimal mock of ResultMessage for the prepared-statement path. + + When NO_METADATA_FLAG is set in the binary stream, recv_results_metadata + reads just the flags + colcount and returns, leaving column_metadata as + None. The closure then falls through to result_metadata (the prepared + statement's stored metadata). + """ + + column_metadata = None + column_names = None + column_types = None + parsed_rows = None + paging_state = None + continuous_paging_seq = None + continuous_paging_last = None + result_metadata_id = None + + def recv_results_metadata(self, f, user_type_map): + """Simulate the prepared-statement path (NO_METADATA_FLAG is set).""" + # Read flags + colcount just like the real recv_results_metadata does + _flags = struct.unpack(">i", f.read(4))[0] + _colcount = struct.unpack(">i", f.read(4))[0] + # NO_METADATA_FLAG is set, so return immediately — column_metadata stays None + + +def _build_integration_binary_buf(nrows, ncols, col_value=b"hello world"): + """ + Build a full binary buffer for the integration benchmark. + + Format for the prepared-statement path: + int32(flags=NO_METADATA_FLAG) -- read by recv_results_metadata + int32(colcount) -- read by recv_results_metadata + int32(rowcount) -- read by BytesIOReader in parse_rows + for each row: + for each col: int32(len) + bytes + """ + parts = [] + parts.append(struct.pack(">i", _NO_METADATA_FLAG)) # flags + parts.append(struct.pack(">i", ncols)) # colcount + parts.append(struct.pack(">i", nrows)) # rowcount + col_cell = struct.pack(">i", len(col_value)) + col_value + row_data = col_cell * ncols + for _ in range(nrows): + parts.append(row_data) + return b"".join(parts) + + +# The actual Cython recv_results_rows closure — this calls _get_or_build_parse_desc internally +_cython_recv_results_rows = make_recv_results_rows(ListParser()) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def _clear_cython_cache(): + """Ensure the Cython ParseDesc cache is empty before and after each test.""" + clear_parse_desc_cache() + yield + clear_parse_desc_cache() + + +@pytest.fixture() +def py_cache(): + """Provide a fresh pure-Python cache dict for each test that needs it.""" + return {} + + +# --------------------------------------------------------------------------- +# Pure-Python cache replica (reference implementation) +# Useful for understanding cost breakdown but does NOT exercise the actual +# Cython cdef inline _get_or_build_parse_desc function. +# --------------------------------------------------------------------------- + +_PY_CACHE_MAX_SIZE = 256 + + +def _cached_parse_desc_py( + column_metadata, column_encryption_policy, protocol_version, cache +): + """Pure-Python replica of the Cython cache for reference comparison.""" + cache_key = id(column_metadata) + cached = cache.get(cache_key) + if cached is not None: + if ( + cached[0] is column_metadata + and cached[1] is column_encryption_policy + and cached[2] == protocol_version + ): + return cached[3], cached[4], cached[5] + + column_names = [md[2] for md in column_metadata] + column_types = [md[3] for md in column_metadata] + desc = ParseDesc( + column_names, + column_types, + column_encryption_policy, + [ColDesc(md[0], md[1], md[2]) for md in column_metadata], + make_deserializers(column_types), + protocol_version, + ) + + if len(cache) >= _PY_CACHE_MAX_SIZE: + cache.clear() + + cache[cache_key] = ( + column_metadata, + column_encryption_policy, + protocol_version, + desc, + column_names, + column_types, + ) + return desc, column_names, column_types + + +# --------------------------------------------------------------------------- +# Integration benchmarks: actual Cython recv_results_rows +# These exercise the real _get_or_build_parse_desc cdef inline function. +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("nrows,ncols", [(1, 10), (100, 5), (1000, 5)]) +def test_integration_cython_cached(benchmark, nrows, ncols): + """Integration: Cython recv_results_rows with ParseDesc cache hit (prepared stmt).""" + col_meta = _build_column_metadata(ncols) + binary_buf = _build_integration_binary_buf(nrows, ncols) + + # Warm the Cython cache with the same col_meta object + warmup = _MockResultMessage() + _cython_recv_results_rows(warmup, io.BytesIO(binary_buf), 4, {}, col_meta, None) + + def run(): + msg = _MockResultMessage() + _cython_recv_results_rows(msg, io.BytesIO(binary_buf), 4, {}, col_meta, None) + return msg.parsed_rows + + rows = benchmark(run) + assert len(rows) == nrows + assert len(rows[0]) == ncols + + +@pytest.mark.parametrize("nrows,ncols", [(1, 10), (100, 5), (1000, 5)]) +def test_integration_cython_uncached(benchmark, nrows, ncols): + """Integration: Cython recv_results_rows with cache miss (fresh metadata each call).""" + binary_buf = _build_integration_binary_buf(nrows, ncols) + + def run(): + # Fresh metadata list each call forces a cache miss + fresh_meta = _build_column_metadata(ncols) + msg = _MockResultMessage() + _cython_recv_results_rows(msg, io.BytesIO(binary_buf), 4, {}, fresh_meta, None) + return msg.parsed_rows + + rows = benchmark(run) + assert len(rows) == nrows + assert len(rows[0]) == ncols + + +# --------------------------------------------------------------------------- +# Integration correctness tests: verify the actual Cython cache behavior +# --------------------------------------------------------------------------- + + +def test_integration_cython_cache_hit(): + """Cython cache returns same column_names/column_types on repeated calls.""" + col_meta = _build_column_metadata(5) + binary_buf = _build_integration_binary_buf(1, 5) + + msg1 = _MockResultMessage() + _cython_recv_results_rows(msg1, io.BytesIO(binary_buf), 4, {}, col_meta, None) + + msg2 = _MockResultMessage() + _cython_recv_results_rows(msg2, io.BytesIO(binary_buf), 4, {}, col_meta, None) + + # Same col_meta object -> cache hit -> same column_names/types objects + assert msg1.column_names is msg2.column_names + assert msg1.column_types is msg2.column_types + + +def test_integration_cython_cache_miss_different_metadata(): + """Different metadata list objects produce cache misses.""" + binary_buf = _build_integration_binary_buf(1, 5) + + col_meta_a = _build_column_metadata(5) + col_meta_b = _build_column_metadata(5) # same shape but different list object + + msg_a = _MockResultMessage() + _cython_recv_results_rows(msg_a, io.BytesIO(binary_buf), 4, {}, col_meta_a, None) + + msg_b = _MockResultMessage() + _cython_recv_results_rows(msg_b, io.BytesIO(binary_buf), 4, {}, col_meta_b, None) + + # Different list objects -> different id() -> cache miss + assert msg_a.column_names is not msg_b.column_names + # But values are equivalent + assert msg_a.column_names == msg_b.column_names + + +def test_integration_cython_cache_invalidation_protocol_version(): + """Changed protocol_version invalidates the Cython cache entry.""" + col_meta = _build_column_metadata(5) + binary_buf = _build_integration_binary_buf(1, 5) + + msg_v4 = _MockResultMessage() + _cython_recv_results_rows(msg_v4, io.BytesIO(binary_buf), 4, {}, col_meta, None) + + msg_v5 = _MockResultMessage() + _cython_recv_results_rows(msg_v5, io.BytesIO(binary_buf), 5, {}, col_meta, None) + + # Same col_meta but different protocol_version -> cache miss -> different objects + assert msg_v4.column_names is not msg_v5.column_names + + +def test_integration_cython_clear_cache(): + """clear_parse_desc_cache() invalidates cached entries.""" + col_meta = _build_column_metadata(5) + binary_buf = _build_integration_binary_buf(1, 5) + + msg1 = _MockResultMessage() + _cython_recv_results_rows(msg1, io.BytesIO(binary_buf), 4, {}, col_meta, None) + + clear_parse_desc_cache() + + msg2 = _MockResultMessage() + _cython_recv_results_rows(msg2, io.BytesIO(binary_buf), 4, {}, col_meta, None) + + # After cache clear, new ParseDesc is built -> different column_names object + assert msg1.column_names is not msg2.column_names + assert msg1.column_names == msg2.column_names + + +def test_integration_cython_parsed_rows_correctness(): + """Integration: verify parsed row data is correct through the Cython path.""" + ncols = 5 + nrows = 3 + col_meta = _build_column_metadata(ncols) + binary_buf = _build_integration_binary_buf(nrows, ncols, col_value=b"test_val") + + msg = _MockResultMessage() + _cython_recv_results_rows(msg, io.BytesIO(binary_buf), 4, {}, col_meta, None) + + assert len(msg.parsed_rows) == nrows + for row in msg.parsed_rows: + assert len(row) == ncols + for val in row: + assert val == "test_val" + assert msg.column_names == ["col_%d" % i for i in range(ncols)] + + +def test_integration_cython_cache_bounded_size(): + """Cython cache evicts entries when exceeding max size.""" + # Fill the cache with many distinct metadata lists + binary_buf = _build_integration_binary_buf(1, 5) + meta_lists = [_build_column_metadata(5) for _ in range(300)] + + for meta in meta_lists: + msg = _MockResultMessage() + _cython_recv_results_rows(msg, io.BytesIO(binary_buf), 4, {}, meta, None) + + # Cache should have been evicted at least once + cache_size = get_parse_desc_cache_size() + assert cache_size <= 256, ( + "Cache should be bounded to 256 entries, got %d" % cache_size + ) + + +# --------------------------------------------------------------------------- +# Isolated benchmarks: ParseDesc construction (reference, pure-Python) +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("ncols", [5, 10, 20]) +def test_parse_desc_build_uncached(benchmark, ncols): + """Reference: build ParseDesc from scratch every time (original code path).""" + col_meta = _build_column_metadata(ncols) + + def run(): + return _build_uncached_parse_desc(col_meta, None, 4) + + result = benchmark(run) + desc, names, types = result + assert len(names) == ncols + assert len(desc.colnames) == ncols + + +@pytest.mark.parametrize("ncols", [5, 10, 20]) +def test_parse_desc_build_cached(benchmark, ncols, py_cache): + """Reference: cached second calls return cached ParseDesc (pure-Python replica).""" + col_meta = _build_column_metadata(ncols) + + # Warm the cache + _cached_parse_desc_py(col_meta, None, 4, py_cache) + + def run(): + return _cached_parse_desc_py(col_meta, None, 4, py_cache) + + result = benchmark(run) + desc, names, types = result + assert len(names) == ncols + assert len(desc.colnames) == ncols + + +# --------------------------------------------------------------------------- +# Isolated benchmarks: Full parse_rows pipeline (reference, pure-Python) +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("nrows,ncols", [(1, 10), (100, 5), (1000, 5)]) +def test_full_pipeline_uncached(benchmark, nrows, ncols): + """Reference: build ParseDesc from scratch + parse rows (pure-Python desc).""" + col_meta = _build_column_metadata(ncols) + binary_buf = _build_binary_rows(nrows, ncols) + parser = ListParser() + + def run(): + desc, names, types = _build_uncached_parse_desc(col_meta, None, 4) + reader = BytesIOReader(binary_buf) + return parser.parse_rows(reader, desc) + + rows = benchmark(run) + assert len(rows) == nrows + assert len(rows[0]) == ncols + + +@pytest.mark.parametrize("nrows,ncols", [(1, 10), (100, 5), (1000, 5)]) +def test_full_pipeline_cached(benchmark, nrows, ncols, py_cache): + """Reference: cached ParseDesc + parse rows (pure-Python cache replica).""" + col_meta = _build_column_metadata(ncols) + binary_buf = _build_binary_rows(nrows, ncols) + parser = ListParser() + + # Warm cache + _cached_parse_desc_py(col_meta, None, 4, py_cache) + + def run(): + desc, names, types = _cached_parse_desc_py(col_meta, None, 4, py_cache) + reader = BytesIOReader(binary_buf) + return parser.parse_rows(reader, desc) + + rows = benchmark(run) + assert len(rows) == nrows + assert len(rows[0]) == ncols + + +# --------------------------------------------------------------------------- +# Isolated benchmarks: ParseDesc only (reference, varying column counts) +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("ncols", [5, 10, 20, 50]) +def test_parse_desc_only_uncached(benchmark, ncols): + """Reference: isolated ParseDesc construction — uncached.""" + col_meta = _build_column_metadata(ncols) + + benchmark(_build_uncached_parse_desc, col_meta, None, 4) + + +@pytest.mark.parametrize("ncols", [5, 10, 20, 50]) +def test_parse_desc_only_cached(benchmark, ncols, py_cache): + """Reference: isolated ParseDesc construction — cached (pure-Python replica).""" + col_meta = _build_column_metadata(ncols) + _cached_parse_desc_py(col_meta, None, 4, py_cache) # warm + + benchmark(_cached_parse_desc_py, col_meta, None, 4, py_cache) + + +# --------------------------------------------------------------------------- +# Reference correctness tests (pure-Python replica) +# --------------------------------------------------------------------------- + + +def test_cached_same_result_as_uncached(py_cache): + """Verify pure-Python cached path produces identical results to uncached.""" + col_meta = _build_column_metadata(10) + + desc_u, names_u, types_u = _build_uncached_parse_desc(col_meta, None, 4) + desc_c, names_c, types_c = _cached_parse_desc_py(col_meta, None, 4, py_cache) + + assert names_u == names_c + assert types_u == types_c + assert len(desc_u.colnames) == len(desc_c.colnames) + assert desc_u.protocol_version == desc_c.protocol_version + + # Second call should be cache hit and return the same desc object + desc_c2, names_c2, types_c2 = _cached_parse_desc_py(col_meta, None, 4, py_cache) + assert desc_c2 is desc_c # same object from cache + + +def test_cache_invalidation_on_different_metadata(py_cache): + """Different column_metadata list should produce a new ParseDesc (pure-Python).""" + col_meta_a = _build_column_metadata(5) + col_meta_b = _build_column_metadata(10) + + desc_a, _, _ = _cached_parse_desc_py(col_meta_a, None, 4, py_cache) + desc_b, _, _ = _cached_parse_desc_py(col_meta_b, None, 4, py_cache) + + assert desc_a is not desc_b + assert len(desc_a.colnames) == 5 + assert len(desc_b.colnames) == 10 + + +def test_cache_invalidation_on_protocol_version_change(py_cache): + """Changed protocol_version should miss the cache (pure-Python).""" + col_meta = _build_column_metadata(5) + desc_v4, _, _ = _cached_parse_desc_py(col_meta, None, 4, py_cache) + desc_v5, _, _ = _cached_parse_desc_py(col_meta, None, 5, py_cache) + + assert desc_v4 is not desc_v5 + + +def test_clear_parse_desc_cache(): + """Verify the Cython cache can be cleared.""" + clear_parse_desc_cache() # should not raise + + +def test_full_pipeline_correctness(py_cache): + """End-to-end: parse rows with cached ParseDesc produces correct data (pure-Python).""" + ncols = 5 + nrows = 3 + col_meta = _build_column_metadata(ncols) + binary_buf = _build_binary_rows(nrows, ncols, col_value=b"test_val") + parser = ListParser() + + desc, names, types = _cached_parse_desc_py(col_meta, None, 4, py_cache) + reader = BytesIOReader(binary_buf) + rows = parser.parse_rows(reader, desc) + + assert len(rows) == nrows + for row in rows: + assert len(row) == ncols + for val in row: + assert val == "test_val" diff --git a/cassandra/__init__.py b/cassandra/__init__.py index 5567c0b9bd..3ad8fcdfd1 100644 --- a/cassandra/__init__.py +++ b/cassandra/__init__.py @@ -23,7 +23,7 @@ def emit(self, record): logging.getLogger('cassandra').addHandler(NullHandler()) -__version_info__ = (3, 29, 8) +__version_info__ = (3, 29, 9) __version__ = '.'.join(map(str, __version_info__)) diff --git a/cassandra/client_routes.py b/cassandra/client_routes.py new file mode 100644 index 0000000000..80b2477a6d --- /dev/null +++ b/cassandra/client_routes.py @@ -0,0 +1,451 @@ +# Copyright 2026 ScyllaDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Client Routes support for Private Link and similar network configurations. + +This module implements support for dynamic address translation via the +system.client_routes table and CLIENT_ROUTES_CHANGE events. +""" + +from __future__ import absolute_import + +from dataclasses import dataclass +import enum +import logging +import socket +import threading +import uuid +from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple + +from cassandra import ConsistencyLevel +from cassandra.protocol import QueryMessage +from cassandra.query import dict_factory + +if TYPE_CHECKING: + from cassandra.connection import Connection + +log = logging.getLogger(__name__) + + +class ClientRoutesChangeType(enum.Enum): + """ + Types of CLIENT_ROUTES_CHANGE events. + + Currently the protocol defines only UPDATE_NODES. + New variants will be added here if the protocol is extended. + """ + UPDATE_NODES = "UPDATE_NODES" + + +@dataclass +class ClientRouteProxy: + """ + :param connection_id: String identifying the connection (required) + :param connection_addr_override:: Optional string address for initial connection + """ + + connection_id: str + connection_addr_override: Optional[str] = None + + def __post_init__(self): + if self.connection_id is None: + raise ValueError("connection_id is required") + +class ClientRoutesConfig: + """ + Configuration for client routes (Private Link support). + + :param proxies: List of :class:`ClientRouteProxy` objects + (REQUIRED, at least one) + :param advanced_shard_awareness: Whether to enable advanced shard awareness + (default: ``False``) + """ + + proxies: List[ClientRouteProxy] + advanced_shard_awareness: bool + + def __init__(self, proxies: List[ClientRouteProxy], advanced_shard_awareness: bool = False): + """ + :param proxies: List of ClientRouteProxy objects + :param advanced_shard_awareness: Enable advanced shard awareness (default False) + """ + if not proxies: + raise ValueError("At least one proxy must be specified") + + if not isinstance(proxies, (list, tuple)): + raise TypeError("proxies must be a list or tuple") + + for proxy in proxies: + if not isinstance(proxy, ClientRouteProxy): + raise TypeError("All proxies must be ClientRouteProxy instances") + + self.proxies = proxies + self.advanced_shard_awareness = advanced_shard_awareness + + def __repr__(self) -> str: + return (f"ClientRoutesConfig(proxies={self.proxies}, " + f"advanced_shard_awareness={self.advanced_shard_awareness})") + + +@dataclass(frozen=True) +class _Route: + connection_id: str + host_id: uuid.UUID + address: str # ipv4, ipv6 or DNS hostname from system.client_routes + port: int + +class _RouteStore: + """ + Thread-safe storage for routes. Reads are safe under CPython's GIL; + writes are serialized with a lock. + + This uses atomic pointer swaps for updates, allowing lock-free reads + while serializing writes. + """ + + _routes_by_host_id: Dict[uuid.UUID, _Route] + _lock: threading.Lock + + def __init__(self) -> None: + self._routes_by_host_id = {} + self._lock = threading.Lock() + + def get_by_host_id(self, host_id: uuid.UUID) -> Optional[_Route]: + """ + Get route for a host ID (lock-free read). + + :param host_id: UUID of the host + :return: _Route or None + """ + return self._routes_by_host_id.get(host_id) + + def get_all(self) -> List[_Route]: + """ + Get all routes as a list (lock-free read). + + :return: List of _Route + """ + return list(self._routes_by_host_id.values()) + + def _select_preferred_routes(self, new_routes: List[_Route]) -> List[_Route]: + """ + When multiple routes exist for the same host_id (different connection_ids), + prefer the connection_id already in use. Only migrate to a different + connection_id when the previously used one is no longer available. + + Must be called under self._lock. + """ + by_host: Dict[uuid.UUID, List[_Route]] = {} + for route in new_routes: + by_host.setdefault(route.host_id, []).append(route) + + selected = [] + for host_id, candidates in by_host.items(): + if len(candidates) == 1: + selected.append(candidates[0]) + continue + + existing = self._routes_by_host_id.get(host_id) + if existing: + preferred = [c for c in candidates if c.connection_id == existing.connection_id] + if preferred: + selected.append(preferred[0]) + continue + + selected.append(candidates[0]) + + return selected + + def update(self, routes: List[_Route]) -> None: + """ + Replace all routes atomically. + + :param routes: List of _Route objects + """ + with self._lock: + preferred = self._select_preferred_routes(routes) + self._routes_by_host_id = {route.host_id: route for route in preferred} + + def merge(self, new_routes: List[_Route], affected_host_ids: Set[uuid.UUID]) -> None: + """ + Merge new routes with existing ones atomically. + + Routes for affected_host_ids are replaced entirely: existing routes + for those hosts are dropped and replaced with whatever is in new_routes. + This handles deletions from system.client_routes (affected host present + but no new route for it). + + :param new_routes: List of _Route objects to merge + :param affected_host_ids: Set of host IDs affected by the change. + """ + with self._lock: + preferred = self._select_preferred_routes(new_routes) + new_by_host = {r.host_id: r for r in preferred} + + updated = {hid: r for hid, r in self._routes_by_host_id.items() + if hid not in affected_host_ids} + updated.update(new_by_host) + self._routes_by_host_id = updated + + +class _ClientRoutesHandler: + """ + Handles dynamic address translation for Private Link via system.client_routes. + + Lifecycle: + 1. Construction: Create with configuration + 2. Initialization: Read system.client_routes after control connection established + 3. Steady state: Listen for CLIENT_ROUTES_CHANGE events and update routes + 4. Translation: Translate addresses using Host ID lookup + """ + + config: 'ClientRoutesConfig' + ssl_enabled: bool + _routes: _RouteStore + _connection_ids: Set[str] + _proxy_addresses_override: Dict[str, str] + + def __init__(self, config: 'ClientRoutesConfig', ssl_enabled: bool = False): + """ + :param config: ClientRoutesConfig instance + :param ssl_enabled: Whether TLS is enabled (determines port selection) + """ + if not isinstance(config, ClientRoutesConfig): + raise TypeError("config must be a ClientRoutesConfig instance") + + self.config = config + self.ssl_enabled = ssl_enabled + self._routes = _RouteStore() + self._connection_ids = {dep.connection_id for dep in config.proxies} + # Precalculate proxy address mappings for efficient lookup + self._proxy_addresses_override = { + proxy.connection_id: proxy.connection_addr_override + for proxy in config.proxies + if proxy.connection_addr_override + } + + def initialize(self, connection: 'Connection', timeout: float) -> None: + """ + Load all routes from system.client_routes. + + Called once at startup and again whenever the control connection + is re-established. Reads all configured connection IDs and + replaces the in-memory route store atomically. + + Raises on failure so the caller can decide how to react (e.g. + abort startup or schedule a reconnect). + + :param connection: The Connection instance to execute queries on + :param timeout: Query timeout in seconds + """ + log.info("[client routes] Loading routes for %d proxies", len(self.config.proxies)) + + routes = self._query_all_routes_for_connections(connection, timeout, self._connection_ids) + self._routes.update(routes) + + def handle_client_routes_change(self, connection: 'Connection', timeout: float, + change_type: 'ClientRoutesChangeType', + connection_ids: Sequence[str], host_ids: Sequence[str]) -> None: + """ + Handle CLIENT_ROUTES_CHANGE event. + + Currently the protocol defines only :attr:`ClientRoutesChangeType.UPDATE_NODES`. + New variants will be added to the enum if the protocol is extended. + + :param connection: The Connection instance to execute queries on + :param timeout: Query timeout in seconds + :param change_type: A :class:`ClientRoutesChangeType` value + :param connection_ids: Affected connection ID strings; empty means all. + :param host_ids: Affected host ID strings; empty means all. + """ + + full_refresh = False + if not connection_ids or not host_ids: + log.warning( + "[client routes] CLIENT_ROUTES_CHANGE has no connection_ids or host_ids, doing full refresh") + full_refresh = True + elif len(connection_ids) != len(host_ids): + log.warning("[client routes] CLIENT_ROUTES_CHANGE has mismatched lengths (conn: %d, host: %d), doing full refresh", + len(connection_ids), len(host_ids)) + full_refresh = True + + if full_refresh: + routes = self._query_all_routes_for_connections(connection, timeout, self._connection_ids) + self._routes.update(routes) + return + + host_uuids = [uuid.UUID(hid) for hid in host_ids] + pairs = [(cid, hid) for cid, hid in zip(connection_ids, host_uuids) + if cid in self._connection_ids] + + if not pairs: + return + + routes = self._query_routes_for_change_event(connection, timeout, pairs) + self._routes.merge(routes, affected_host_ids=set(host_uuids)) + + def _query_all_routes_for_connections(self, connection: 'Connection', timeout: float, + connection_ids: Set[str]) -> List[_Route]: + """ + Query all routes for the given connection IDs (complete refresh). + + Used when control connection reconnects or as a fallback when + CLIENT_ROUTES_CHANGE event has malformed data. + + :param connection: Connection to execute query on + :param timeout: Query timeout in seconds + :param connection_ids: Set of connection ID strings + :return: List of _Route + """ + if not connection_ids: + return [] + + placeholders = ', '.join('?' for _ in connection_ids) + query = f"SELECT connection_id, host_id, address, port, tls_port FROM system.client_routes WHERE connection_id IN ({placeholders})" + params = [cid.encode('utf-8') for cid in connection_ids] + + log.debug("[client routes] Querying all routes for connection_ids=%s", connection_ids) + return self._execute_routes_query(connection, timeout, query, params) + + def _query_routes_for_change_event(self, connection: 'Connection', timeout: float, + route_pairs: List[Tuple[str, uuid.UUID]]) -> List[_Route]: + """ + Query specific routes affected by a CLIENT_ROUTES_CHANGE event. + + Takes a list of (connection_id, host_id) pairs that represent the exact + routes affected by an operation. This provides precise updates without + fetching unrelated routes. + + If the pairs list is empty or None, falls back to a complete refresh + of all routes for safety. + + :param connection: Connection to execute query on + :param timeout: Query timeout in seconds + :param route_pairs: List of (connection_id, host_id) tuples + :return: List of _Route + """ + unique_pairs = list(dict.fromkeys(route_pairs)) + + conn_ids = list(dict.fromkeys(cid for cid, _ in unique_pairs)) + host_ids = list(dict.fromkeys(hid for _, hid in unique_pairs)) + + log.debug("[client routes] Querying route pairs from CLIENT_ROUTES_CHANGE " + "(first 5 of %d): %s", len(unique_pairs), unique_pairs[:5]) + + conn_ph = ', '.join('?' for _ in conn_ids) + host_ph = ', '.join('?' for _ in host_ids) + query = ( + "SELECT connection_id, host_id, address, port, tls_port " + "FROM system.client_routes " + f"WHERE connection_id IN ({conn_ph}) AND host_id IN ({host_ph})" + ) + params: List = [cid.encode('utf-8') for cid in conn_ids] + params.extend(hid.bytes for hid in host_ids) + + return self._execute_routes_query(connection, timeout, query, params) + + def _execute_routes_query(self, connection: 'Connection', timeout: float, + query: str, params: List) -> List[_Route]: + """ + Execute a routes query and parse results. + + Common helper for both complete refresh and change event queries. + + :param connection: Connection to execute query on + :param timeout: Query timeout in seconds + :param query: CQL query string + :param params: Query parameters + :return: List of _Route + """ + log.debug("[client routes] Executing query: %s with %d parameters", query, len(params)) + + query_msg = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE, + query_params=params if params else None) + result = connection.wait_for_response( + query_msg, timeout=timeout + ) + + routes = [] + broken = 0 + rows = dict_factory(result.column_names, result.parsed_rows) + for row in rows: + try: + absent = [] + port = row['tls_port'] if self.ssl_enabled else row['port'] + connection_id = row['connection_id'] + host_id = row['host_id'] + address = row['address'] + + if not port: + absent.append("tls_port" if self.ssl_enabled else "port") + if not connection_id: + absent.append("connection_id") + if not host_id: + absent.append("host_id") + if not address: + absent.append("address") + + if absent: + log.error("[client routes] read a route %s, that has no values for the following fields: %s", row, ",".join(absent)) + broken += 1 + continue + + final_address = self._proxy_addresses_override.get(connection_id, address) + + routes.append(_Route( + connection_id=connection_id, + host_id=host_id, + address=final_address, + port=port, + )) + except Exception as e: + log.warning("[client routes] Failed to parse route row: %s", e) + broken += 1 + + if broken and not routes: + raise RuntimeError( + "[client routes] All %d route rows failed validation; " + "refusing to return empty result that would wipe the route store" % broken + ) + + return routes + + def resolve_host(self, host_id: uuid.UUID) -> Optional[Tuple[str, int]]: + """ + Resolve a host_id to an (address, port) pair. + + Looks up the current route and selects the appropriate port. + + :param host_id: Host UUID to resolve + :return: Tuple of (address, port) or None if no route mapping exists + """ + route = self._routes.get_by_host_id(host_id) + if route is None: + return None + + if not route.port: + raise ValueError("Mapping for host %s has no port" % host_id) + + try: + result = socket.getaddrinfo(route.address, route.port, + socket.AF_UNSPEC, socket.SOCK_STREAM) + if not result: + raise socket.gaierror("No addresses found for %s" % route.address) + resolved_ip = result[0][4][0] + return resolved_ip, route.port + except socket.gaierror as e: + log.warning('[client routes] Could not resolve hostname "%s" (host_id=%s): %s', + route.address, host_id, e) + raise diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 51d0b2d88b..8da9df6a55 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -29,7 +29,7 @@ from itertools import groupby, count, chain import json import logging -from typing import Optional, Union +from typing import Any, Dict, Optional, Union from warnings import warn from random import random import re @@ -48,7 +48,8 @@ SchemaTargetType, DriverException, ProtocolVersion, UnresolvableContactPoints, DependencyException) from cassandra.auth import _proxy_execute_key, PlainTextAuthProvider -from cassandra.connection import (ConnectionException, ConnectionShutdown, +from cassandra.client_routes import ClientRoutesChangeType, ClientRoutesConfig, _ClientRoutesHandler +from cassandra.connection import (ClientRoutesEndPointFactory, ConnectionException, ConnectionShutdown, ConnectionHeartbeat, ProtocolVersionUnsupported, EndPoint, DefaultEndPoint, DefaultEndPointFactory, SniEndPointFactory, ConnectionBusy, locally_supported_compressions) @@ -1215,7 +1216,8 @@ def __init__(self, shard_aware_options=None, metadata_request_timeout: Optional[float] = None, column_encryption_policy=None, - application_info:Optional[ApplicationInfoBase]=None + application_info:Optional[ApplicationInfoBase]=None, + client_routes_config:Optional[ClientRoutesConfig]=None ): """ ``executor_threads`` defines the number of threads in a pool for handling asynchronous tasks such as @@ -1280,6 +1282,45 @@ def __init__(self, if column_encryption_policy is not None: self.column_encryption_policy = column_encryption_policy + if client_routes_config is not None and endpoint_factory is not None: + raise ValueError("client_routes_config and endpoint_factory are mutually exclusive") + + self._client_routes_handler = None + if client_routes_config is not None: + if not isinstance(client_routes_config, ClientRoutesConfig): + raise TypeError("client_routes_config must be a ClientRoutesConfig instance") + + # SSL hostname verification is incompatible with client routes: + # connections go through NLB proxies whose addresses won't match + # server certificates. + _check_hostname_enabled = False + if ssl_context is not None and ssl_context.check_hostname: + _check_hostname_enabled = True + if ssl_options is not None and ssl_options.get('check_hostname', False): + _check_hostname_enabled = True + if _check_hostname_enabled: + raise ValueError( + "SSL hostname verification (check_hostname=True) is currently incompatible " + "with client_routes_config. When using client routes, connections " + "go through NLB proxies whose addresses won't match server " + "certificates. Disable hostname verification by setting " + "ssl_context.check_hostname = False." + ) + + ssl_enabled = ssl_context is not None or ssl_options is not None + self._client_routes_handler = _ClientRoutesHandler(client_routes_config, ssl_enabled=ssl_enabled) + + if contact_points is _NOT_SET or not self._contact_points_explicit: + seed_addrs = [dep.connection_addr_override for dep in client_routes_config.proxies + if dep.connection_addr_override] + if seed_addrs: + self.contact_points = seed_addrs + self._contact_points_explicit = True + log.info("[client routes] Using %d deployment connection addresses as contact points", + len(seed_addrs)) + + if self._client_routes_handler is not None: + endpoint_factory = ClientRoutesEndPointFactory(self._client_routes_handler, self.port) self.endpoint_factory = endpoint_factory or DefaultEndPointFactory(port=self.port) self.endpoint_factory.configure(self) @@ -1437,6 +1478,10 @@ def __init__(self, self.monitor_reporting_interval = monitor_reporting_interval self.shard_aware_options = ShardAwareOptions(opts=shard_aware_options) + if (client_routes_config is not None + and not client_routes_config.advanced_shard_awareness): + self.shard_aware_options.disable_shardaware_port = True + self._listeners = set() self._listener_lock = Lock() @@ -3612,11 +3657,21 @@ def _try_connect(self, endpoint): # this object (after a dereferencing a weakref) self_weakref = weakref.ref(self, partial(_clear_watcher, weakref.proxy(connection))) try: - connection.register_watchers({ + watchers = { "TOPOLOGY_CHANGE": partial(_watch_callback, self_weakref, '_handle_topology_change'), "STATUS_CHANGE": partial(_watch_callback, self_weakref, '_handle_status_change'), "SCHEMA_CHANGE": partial(_watch_callback, self_weakref, '_handle_schema_change') - }, register_timeout=self._timeout) + } + + if self._cluster._client_routes_handler is not None: + watchers["CLIENT_ROUTES_CHANGE"] = partial(_watch_callback, self_weakref, '_handle_client_routes_change') + + connection.register_watchers(watchers, register_timeout=self._timeout) + + if self._cluster._client_routes_handler is not None: + self._cluster._client_routes_handler.initialize( + connection, + self._timeout) sel_peers = self._get_peers_query(self.PeersQueryType.PEERS, connection) sel_local = self._SELECT_LOCAL if self._token_meta_enabled else self._SELECT_LOCAL_NO_TOKENS @@ -3979,6 +4034,44 @@ def _handle_status_change(self, event): # this will be run by the scheduler self._cluster.on_down(host, is_host_addition=False) + def _handle_client_routes_change(self, event: Dict[str, Any]) -> None: + """ + Handle CLIENT_ROUTES_CHANGE event from the server. + + This event indicates that the system.client_routes table has been updated + and we need to refresh our route mappings. + """ + if self._cluster._client_routes_handler is None: + log.warning("[control connection] Received CLIENT_ROUTES_CHANGE but no handler configured") + return + + raw_change_type = event.get("change_type") + try: + change_type = ClientRoutesChangeType(raw_change_type) + except ValueError: + log.warning("[control connection] Unknown CLIENT_ROUTES_CHANGE type: %s", raw_change_type) + return + + connection_ids = tuple(event.get("connection_ids", [])) + host_ids = tuple(event.get("host_ids", [])) + + self._cluster.scheduler.schedule_unique( + 0, + self._handle_client_routes_refresh, + self._connection, self._timeout, change_type, connection_ids, host_ids + ) + + def _handle_client_routes_refresh(self, connection, timeout, + change_type, connection_ids, host_ids): + try: + self._cluster._client_routes_handler.handle_client_routes_change( + connection, timeout, change_type, connection_ids, host_ids) + except ReferenceError: + pass # our weak reference to the Cluster is no good + except Exception: + log.debug("[control connection] Error handling CLIENT_ROUTES_CHANGE", exc_info=True) + self._signal_error() + def _handle_schema_change(self, event): if self._schema_event_refresh_window < 0: return diff --git a/cassandra/connection.py b/cassandra/connection.py index 87f860f32b..72b273ec37 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -25,12 +25,14 @@ from threading import Thread, Event, RLock, Condition import time import ssl +import uuid import weakref import random import itertools -from typing import Optional, Union +from typing import Any, Dict, Optional, Tuple, Union from cassandra.application_info import ApplicationInfoBase +from cassandra.client_routes import _ClientRoutesHandler from cassandra.protocol_features import ProtocolFeatures if 'gevent.monkey' in sys.modules: @@ -230,7 +232,7 @@ class DefaultEndPointFactory(EndPointFactory): port = None """ If no port is discovered in the row, this is the default port - used for endpoint creation. + used for endpoint creation. """ def __init__(self, port=None): @@ -328,6 +330,50 @@ def create_from_sni(self, sni): return SniEndPoint(self._proxy_address, sni, self._port) +class ClientRoutesEndPointFactory(EndPointFactory): + """ + EndPointFactory for Client Routes (Private Link) support. + + Creates ClientRoutesEndPoint instances that defer both address translation + (host_id -> hostname lookup) and DNS resolution until connection time. + This ensures immediate reaction to infrastructure changes. + """ + + client_routes_handler: _ClientRoutesHandler + default_port: int + + def __init__(self, client_routes_handler: _ClientRoutesHandler, default_port: int = None) -> None: + """ + :param client_routes_handler: _ClientRoutesHandler instance to lookup routes + :param default_port: Default port if none found in row + """ + self.client_routes_handler = client_routes_handler + self.default_port = default_port + + def create(self, row: Dict[str, Any]) -> 'ClientRoutesEndPoint': + """ + Create a ClientRoutesEndPoint from a system.peers row. + + Stores only the host_id and handler reference. Both translation + (route lookup) and DNS resolution happen later in resolve(). + """ + from cassandra.metadata import _NodeInfo + host_id = row.get("host_id") + + if host_id is None: + raise ValueError("No host_id to create ClientRoutesEndPoint") + + addr = _NodeInfo.get_broadcast_rpc_address(row) + port = _NodeInfo.get_broadcast_rpc_port(row) or _NodeInfo.get_broadcast_port(row) or self.default_port + + return ClientRoutesEndPoint( + host_id=host_id, + handler=self.client_routes_handler, + original_address=addr, + original_port=port, + ) + + @total_ordering class UnixSocketEndPoint(EndPoint): """ @@ -369,6 +415,76 @@ def __repr__(self): return "<%s: %s>" % (self.__class__.__name__, self._unix_socket_path) +@total_ordering +class ClientRoutesEndPoint(EndPoint): + """ + Client Routes (Private Link) EndPoint implementation. + + Defers both address translation (route lookup) and DNS resolution + until resolve() is called at connection time. This ensures immediate + reaction to infrastructure changes and CLIENT_ROUTES_CHANGE events. + """ + + _host_id: uuid.UUID + _handler: _ClientRoutesHandler + _original_address: str + _original_port: int + + def __init__(self, host_id: uuid.UUID, handler: _ClientRoutesHandler, original_address: str, original_port: int = None) -> None: + """ + :param host_id: Host UUID for route lookup + :param handler: _ClientRoutesHandler instance + :param original_address: Original address from system.peers (for identification) + :param original_port: Original port if route doesn't specify one + """ + self._host_id = host_id + self._handler = handler + self._original_address = original_address + self._original_port = original_port + + @property + def address(self) -> str: + """Returns the original address (updated by resolve()).""" + return self._original_address + + @property + def port(self) -> Optional[int]: + return self._original_port + + @property + def host_id(self) -> uuid.UUID: + return self._host_id + + def resolve(self) -> Tuple[str, int]: + """ + Resolve endpoint by delegating to the handler. + Falls back to original address/port if no route mapping is available. + """ + result = self._handler.resolve_host(self._host_id) + if result is None: + return self._original_address, self._original_port + return result + + def __eq__(self, other): + return (isinstance(other, ClientRoutesEndPoint) and + self._host_id == other._host_id and + self._original_address == other._original_address) + + def __hash__(self): + return hash((self._host_id, self._original_address)) + + def __lt__(self, other): + return ((self._host_id, self._original_address) < + (other._host_id, other._original_address)) + + def __str__(self): + return str("%s (host_id=%s)" % (self._original_address, self._host_id)) + + def __repr__(self): + return "<%s: host_id=%s, original_addr=%s>" % ( + self.__class__.__name__, self._host_id, self._original_address) + + class _Frame(object): def __init__(self, version, flags, stream, opcode, body_offset, end_pos): self.version = version diff --git a/cassandra/protocol.py b/cassandra/protocol.py index f37633a756..4628c7ee0e 100644 --- a/cassandra/protocol.py +++ b/cassandra/protocol.py @@ -611,9 +611,10 @@ class QueryMessage(_QueryMessage): name = 'QUERY' def __init__(self, query, consistency_level, serial_consistency_level=None, - fetch_size=None, paging_state=None, timestamp=None, continuous_paging_options=None, keyspace=None): + fetch_size=None, paging_state=None, timestamp=None, continuous_paging_options=None, keyspace=None, + query_params=None): self.query = query - super(QueryMessage, self).__init__(None, consistency_level, serial_consistency_level, fetch_size, + super(QueryMessage, self).__init__(query_params, consistency_level, serial_consistency_level, fetch_size, paging_state, timestamp, False, continuous_paging_options, keyspace) def send_body(self, f, protocol_version): diff --git a/cassandra/row_parser.pyx b/cassandra/row_parser.pyx index 88277a4593..41182cd7e1 100644 --- a/cassandra/row_parser.pyx +++ b/cassandra/row_parser.pyx @@ -19,6 +19,92 @@ from cassandra.deserializers import make_deserializers include "ioutils.pyx" +# Maximum number of entries in the ParseDesc cache. Each entry corresponds +# to a distinct PreparedStatement's result_metadata. 256 should be more +# than enough for most applications; when exceeded the entire cache is +# cleared (simple eviction strategy that avoids per-entry bookkeeping). +cdef int _PARSE_DESC_CACHE_MAX_SIZE = 256 + +# Cache for ParseDesc objects keyed by id(column_metadata). +# For prepared statements, result_metadata is stored on PreparedStatement +# and reused across executions, so id() is stable. The cache is only +# populated on the prepared-statement path (where column_metadata comes from +# result_metadata); inline metadata from non-prepared queries is always fresh +# and must not be cached to avoid unbounded growth. +# +# Cache value: (column_metadata_ref, column_encryption_policy_ref, +# protocol_version, desc, column_names, column_types) +# +# Thread safety: individual dict get/set operations are atomic under +# CPython's GIL and under free-threaded builds (PEP 703, which uses +# per-object locks on dicts). The get-miss-build-set sequence is NOT +# transactionally atomic: two threads may both miss and both build a +# ParseDesc for the same key, with the last write winning. This is +# benign -- no data corruption occurs, only redundant construction work +# on the first concurrent miss. No additional locking is needed. +cdef dict _parse_desc_cache = {} + +cdef inline tuple _get_or_build_parse_desc(object column_metadata, object column_encryption_policy, int protocol_version): + """Look up or build a ParseDesc for the given column_metadata (cached path). + + Returns (desc, column_names, column_types). + """ + cdef object cache_key = id(column_metadata) + cdef object cached_or_none = _parse_desc_cache.get(cache_key) + + if cached_or_none is not None: + # Verify identity -- the object at this id must be the same list + # and session-level settings must match. + cached = cached_or_none + if (cached[0] is column_metadata and + cached[1] is column_encryption_policy and + cached[2] == protocol_version): + return (cached[3], cached[4], cached[5]) # hit: (desc, names, types) + + # Cache miss -- build everything + cdef list column_names = [md[2] for md in column_metadata] + cdef list column_types = [md[3] for md in column_metadata] + cdef object desc = ParseDesc( + column_names, column_types, column_encryption_policy, + [ColDesc(md[0], md[1], md[2]) for md in column_metadata], + make_deserializers(column_types), protocol_version) + + # Simple bounded eviction: if the cache is too large, clear it entirely. + # This avoids per-entry bookkeeping (LRU lists, timestamps) that would + # add overhead on the hot cache-hit path. In practice the cache holds + # one entry per prepared statement, so 256 is generous. + if len(_parse_desc_cache) >= _PARSE_DESC_CACHE_MAX_SIZE: + _parse_desc_cache.clear() + + _parse_desc_cache[cache_key] = (column_metadata, column_encryption_policy, + protocol_version, desc, column_names, column_types) + return (desc, column_names, column_types) + + +cdef inline tuple _build_parse_desc(object column_metadata, object column_encryption_policy, int protocol_version): + """Build a ParseDesc without caching (for non-prepared inline metadata). + + Returns (desc, column_names, column_types). + """ + cdef list column_names = [md[2] for md in column_metadata] + cdef list column_types = [md[3] for md in column_metadata] + cdef object desc = ParseDesc( + column_names, column_types, column_encryption_policy, + [ColDesc(md[0], md[1], md[2]) for md in column_metadata], + make_deserializers(column_types), protocol_version) + return (desc, column_names, column_types) + + +def clear_parse_desc_cache(): + """Clear the ParseDesc cache. Exposed for testing.""" + _parse_desc_cache.clear() + + +def get_parse_desc_cache_size(): + """Return the current number of entries in the ParseDesc cache. Exposed for testing.""" + return len(_parse_desc_cache) + + def make_recv_results_rows(ColumnParser colparser): def recv_results_rows(self, f, int protocol_version, user_type_map, result_metadata, column_encryption_policy): """ @@ -29,12 +115,18 @@ def make_recv_results_rows(ColumnParser colparser): column_metadata = self.column_metadata or result_metadata - self.column_names = [md[2] for md in column_metadata] - self.column_types = [md[3] for md in column_metadata] + # Only use the cache for prepared statements (self.column_metadata is + # None, so column_metadata comes from result_metadata which is a + # stable list stored on PreparedStatement). Inline metadata from + # non-prepared queries creates a fresh list every time and would + # cause unbounded cache growth. + if self.column_metadata is None and result_metadata is not None: + desc, self.column_names, self.column_types = _get_or_build_parse_desc( + column_metadata, column_encryption_policy, protocol_version) + else: + desc, self.column_names, self.column_types = _build_parse_desc( + column_metadata, column_encryption_policy, protocol_version) - desc = ParseDesc(self.column_names, self.column_types, column_encryption_policy, - [ColDesc(md[0], md[1], md[2]) for md in column_metadata], - make_deserializers(self.column_types), protocol_version) reader = BytesIOReader(f.read()) try: self.parsed_rows = colparser.parse_rows(reader, desc) diff --git a/docs/conf.py b/docs/conf.py index 403908c29e..4b6b329525 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -29,10 +29,11 @@ '3.29.6-scylla', '3.29.7-scylla', '3.29.8-scylla', + '3.29.9-scylla', ] BRANCHES = ['master'] # Set the latest version. -LATEST_VERSION = '3.29.8-scylla' +LATEST_VERSION = '3.29.9-scylla' # Set which versions are not released yet. UNSTABLE_VERSIONS = ['master'] # Set which versions are deprecated diff --git a/docs/installation.rst b/docs/installation.rst index 4207c46092..7b4823b832 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -26,7 +26,7 @@ To check if the installation was successful, you can run:: python -c 'import cassandra; print(cassandra.__version__)' -It should print something like "3.29.8". +It should print something like "3.29.9". (*Optional*) Compression Support -------------------------------- @@ -199,7 +199,7 @@ through `Homebrew `_. For example, on Mac OS X:: $ brew install libev -The libev extension can now be built for Windows as of Python driver version 3.29.8. You can +The libev extension can now be built for Windows as of Python driver version 3.29.9. You can install libev using any Windows package manager. For example, to install using `vcpkg `_: $ vcpkg install libev diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py index dfac2dc1d9..a53e7aafa6 100644 --- a/tests/integration/__init__.py +++ b/tests/integration/__init__.py @@ -715,6 +715,27 @@ def xfail_scylla_version_lt(reason, oss_scylla_version, ent_scylla_version, *arg return pytest.mark.xfail(current_version < Version(oss_scylla_version), reason=reason, *args, **kwargs) + +def skip_scylla_version_lt(reason, scylla_version): + """ + Skip tests on scylla versions older than the specified thresholds. + :param reason: message explaining why the test is skipped + :param scylla_version: str, version from which test supposed to work + """ + if not (reason.startswith("scylladb/scylladb#") or reason.startswith("scylladb/scylla-enterprise#")): + raise ValueError('reason should start with scylladb/scylladb# or scylladb/scylla-enterprise# to reference issue in scylla repo') + + if not isinstance(scylla_version, str): + raise ValueError('scylla_version should be a str') + + if SCYLLA_VERSION is None: + return pytest.mark.skipif(False, reason="It is just a NoOP Decor, should not skip anything") + + current_version = Version(get_scylla_version(SCYLLA_VERSION)) + + return pytest.mark.skipif(current_version < Version(scylla_version), reason=reason) + + class UpDownWaiter(object): def __init__(self, host): diff --git a/tests/integration/standard/test_client_routes.py b/tests/integration/standard/test_client_routes.py new file mode 100644 index 0000000000..a8a3c30f2c --- /dev/null +++ b/tests/integration/standard/test_client_routes.py @@ -0,0 +1,1314 @@ +# Copyright 2026 ScyllaDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Comprehensive integration tests for Client Routes (Private Link) support. + +Includes: +- TCP proxy and NLB emulator for simulating private link infrastructure +- Tests verifying all connections go exclusively through the proxy +- Tests for dynamic route updates and topology changes +- Tests for query_routes filtering +""" + +import logging +import os +import select +import shutil +import socket +import ssl +import subprocess +import tempfile +import threading +import time +import unittest +import uuid + +import json as _json +import urllib.request + +from cassandra.cluster import Cluster +from cassandra.client_routes import ClientRoutesConfig, ClientRouteProxy +from cassandra.connection import ClientRoutesEndPoint +from cassandra.policies import RoundRobinPolicy +from tests.integration import ( + TestCluster, + get_cluster, + get_node, + use_cluster, + wait_for_node_socket, + skip_scylla_version_lt, +) +from tests.util import wait_until_not_raised + +log = logging.getLogger(__name__) + +class TcpProxy: + """ + A simple TCP proxy that forwards connections from a local listen port + to a target (host, port). Tracks active connections so tests can + verify that traffic flows through the proxy. + """ + + BUF_SIZE = 65536 + + def __init__(self, listen_host, listen_port, target_host, target_port): + self.listen_host = listen_host + self.listen_port = listen_port + self.target_host = target_host + self.target_port = target_port + + self._server_sock = None + self._running = False + self._thread = None + self._lock = threading.Lock() + self._connections = set() + self.total_connections = 0 + + def start(self): + self._server_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._server_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self._server_sock.bind((self.listen_host, self.listen_port)) + self.listen_port = self._server_sock.getsockname()[1] + self._server_sock.listen(128) + self._server_sock.setblocking(False) + self._running = True + self._thread = threading.Thread(target=self._run, daemon=True, + name="proxy-%s:%d" % (self.listen_host, self.listen_port)) + self._thread.start() + log.info("TcpProxy started %s:%d -> %s:%d", + self.listen_host, self.listen_port, + self.target_host, self.target_port) + + def stop(self): + self._running = False + if self._server_sock: + try: + self._server_sock.close() + except Exception: + pass + with self._lock: + for csock, tsock in list(self._connections): + self._close_pair(csock, tsock) + self._connections.clear() + if self._thread: + self._thread.join(timeout=5) + log.info("TcpProxy stopped %s:%d", self.listen_host, self.listen_port) + + @property + def active_connections(self): + with self._lock: + return len(self._connections) + + def retarget(self, new_host, new_port): + """Change the backend target for new connections (existing ones keep the old target).""" + self.target_host = new_host + self.target_port = new_port + log.info("TcpProxy %s:%d retargeted to %s:%d", + self.listen_host, self.listen_port, new_host, new_port) + + def drop_connections(self): + """Forcibly close all active connections.""" + with self._lock: + for csock, tsock in list(self._connections): + self._close_pair(csock, tsock) + self._connections.clear() + log.info("TcpProxy %s:%d dropped all connections", self.listen_host, self.listen_port) + + def _run(self): + while self._running: + try: + readable, _, _ = select.select([self._server_sock], [], [], 0.2) + except (ValueError, OSError): + break + for sock in readable: + if sock is self._server_sock: + try: + client_sock, _ = self._server_sock.accept() + except OSError: + continue + self._handle_new_connection(client_sock) + + def _handle_new_connection(self, client_sock, target_host=None, target_port=None): + target_host = target_host or self.target_host + target_port = target_port or self.target_port + try: + target_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + target_sock.connect((target_host, target_port)) + except Exception as e: + log.warning("TcpProxy %s:%d failed to connect to target %s:%d: %s", + self.listen_host, self.listen_port, + target_host, target_port, e) + client_sock.close() + return + + with self._lock: + self._connections.add((client_sock, target_sock)) + self.total_connections += 1 + + t = threading.Thread(target=self._forward_loop, + args=(client_sock, target_sock), + daemon=True) + t.start() + + def _forward_loop(self, client_sock, target_sock): + try: + while self._running: + readable, _, _ = select.select([client_sock, target_sock], [], [], 0.5) + for sock in readable: + data = sock.recv(self.BUF_SIZE) + if not data: + return + if sock is client_sock: + target_sock.sendall(data) + else: + client_sock.sendall(data) + except (OSError, ConnectionResetError, BrokenPipeError): + pass + finally: + with self._lock: + self._connections.discard((client_sock, target_sock)) + self._close_pair(client_sock, target_sock) + + @staticmethod + def _close_pair(csock, tsock): + for s in (csock, tsock): + try: + s.close() + except Exception: + pass + + +class NLBEmulator: + """ + Emulates a Network Load Balancer for a CCM cluster. + + Provides: + - One *discovery port* (round-robin across all live nodes, used as the + driver's ``contact_points``). + - One *per-node port* for each node (dedicated proxy to that node's + native transport port). + + All proxies listen on ``LISTEN_HOST`` (127.254.254.101), an address + outside the CCM node range, simulating a real NLB endpoint. + + Port layout (all ports are OS-assigned by default): + LISTEN_HOST:discovery_port -> round-robin to all live nodes + LISTEN_HOST: -> node1 (127.0.0.1:9042) + LISTEN_HOST: -> node2 (127.0.0.2:9042) + ... + + Automatically creates/removes per-node proxies when nodes are + added/removed so CCM cluster operations are reflected seamlessly. + """ + + LISTEN_HOST = "127.254.254.101" + + def __init__(self, discovery_port=0, + per_node_base=0, + native_port=9042, + node_addresses=None): + self.discovery_port = discovery_port + self.per_node_base = per_node_base + self.native_port = native_port + self._deferred_node_addresses = node_addresses + + self._node_proxies = {} + self._discovery_proxy = None + self._rr_index = 0 + self._lock = threading.Lock() + self._running = False + + def start(self, node_addresses): + """ + Start the NLB with an initial set of node addresses. + + :param node_addresses: dict of node_id -> ip_address, e.g. + {1: "127.0.0.1", 2: "127.0.0.2"} + """ + self._running = True + try: + for node_id, addr in node_addresses.items(): + self._add_node_proxy(node_id, addr) + + first_addr = list(node_addresses.values())[0] + self._discovery_proxy = TcpProxy( + self.LISTEN_HOST, self.discovery_port, + first_addr, self.native_port, + ) + self._discovery_proxy.start() + self.discovery_port = self._discovery_proxy.listen_port + except Exception: + self.stop() + raise + original_handler = self._discovery_proxy._handle_new_connection + + def rr_handler(client_sock): + addrs = self._live_addresses() + if not addrs: + client_sock.close() + return + idx = self._rr_index % len(addrs) + self._rr_index += 1 + addr = addrs[idx] + original_handler(client_sock, target_host=addr, target_port=self.native_port) + + self._discovery_proxy._handle_new_connection = rr_handler + + log.info("NLB started: discovery=%s:%d, %d node proxies", + self.LISTEN_HOST, self.discovery_port, len(self._node_proxies)) + return self + + def __enter__(self): + if not self._running and self._deferred_node_addresses is not None: + self.start(self._deferred_node_addresses) + return self + + def __exit__(self, *args): + self.stop() + + def stop(self): + self._running = False + if self._discovery_proxy: + self._discovery_proxy.stop() + for proxy in self._node_proxies.values(): + proxy.stop() + self._node_proxies.clear() + log.info("NLB stopped") + + def add_node(self, node_id, addr): + self._add_node_proxy(node_id, addr) + + def remove_node(self, node_id): + with self._lock: + proxy = self._node_proxies.pop(node_id, None) + if proxy: + proxy.stop() + log.info("NLB removed node %d", node_id) + + def node_port(self, node_id): + proxy = self._node_proxies.get(node_id) + if proxy: + return proxy.listen_port + return self.per_node_base + node_id + + def get_node_proxy(self, node_id): + return self._node_proxies.get(node_id) + + def total_proxy_connections(self): + return sum(p.total_connections for p in self._node_proxies.values()) + + def active_proxy_connections(self): + return sum(p.active_connections for p in self._node_proxies.values()) + + def drop_all_connections(self): + for proxy in self._node_proxies.values(): + proxy.drop_connections() + if self._discovery_proxy: + self._discovery_proxy.drop_connections() + + def _add_node_proxy(self, node_id, addr): + port = 0 + proxy = TcpProxy(self.LISTEN_HOST, port, addr, self.native_port) + proxy.start() + with self._lock: + self._node_proxies[node_id] = proxy + log.info("NLB added node %d: %s:%d -> %s:%d", + node_id, self.LISTEN_HOST, port, addr, self.native_port) + + def _live_addresses(self): + """IPs of nodes with active proxies.""" + return [p.target_host for p in self._node_proxies.values()] + +def post_client_routes(contact_point, routes): + """ + Post client routes to Scylla's REST API. + + :param contact_point: IP/hostname of a Scylla node (e.g. "127.0.0.1") + :param routes: List of route dicts with keys: connection_id, host_id, address, port + and optionally tls_port + """ + payload = [] + for route in routes: + entry = { + "connection_id": str(route["connection_id"]), + "host_id": str(route["host_id"]), + "address": route["address"], + "port": route["port"], + } + if route.get("tls_port") is not None: + entry["tls_port"] = route["tls_port"] + payload.append(entry) + + url = "http://%s:10000/v2/client-routes" % contact_point + log.info("Posting %d routes to %s", len(payload), url) + data = _json.dumps(payload).encode("utf-8") + req = urllib.request.Request( + url, + data=data, + headers={ + "Content-Type": "application/json", + "Accept": "application/json", + }, + method="POST", + ) + response = urllib.request.urlopen(req) + log.info("Routes posted successfully (status %d)", response.status) + + +def get_host_ids_from_cluster(session): + """ + Build a mapping of rpc_address -> host_id for all nodes in the cluster. + + Uses the driver's metadata rather than querying system.local / system.peers + directly, because those queries can be routed to different coordinators + (system.local returns the coordinator's own info while system.peers omits + the coordinator), leading to a node being missing from the map. + """ + host_id_map = {} + for host in session.cluster.metadata.all_hosts(): + host_id_map[host.address] = host.host_id + return host_id_map + + +def build_routes_for_nlb(connection_id, host_id_map, nlb): + """ + Build routes that direct each host_id through the NLB per-node proxy. + + :param connection_id: Connection ID string + :param host_id_map: dict ip -> uuid host_id (from get_host_ids_from_cluster) + :param nlb: NLBEmulator instance + :return: list of route dicts + """ + routes = [] + for ip, host_id in host_id_map.items(): + node_id = int(ip.split(".")[-1]) + port = nlb.node_port(node_id) + routes.append({ + "connection_id": connection_id, + "host_id": host_id, + "address": NLBEmulator.LISTEN_HOST, + "port": port, + }) + return routes + + +def post_routes_for_nlb(contact_point, connection_id, host_id_map, nlb): + """Build routes for the NLB and POST them via the REST API.""" + routes = build_routes_for_nlb(connection_id, host_id_map, nlb) + post_client_routes(contact_point, routes) + return routes + +def wait_for_routes_visible(session, connection_id, expected_count, timeout=10, poll_interval=0.1): + """ + Poll system.client_routes on **every** node until each one sees at + least *expected_count* rows for *connection_id*. + + ``system.client_routes`` is a node-local table, so routes posted via + the REST API to one node are not guaranteed to be visible on the + others at the same time. This helper ensures they have propagated + everywhere before the test proceeds. + + :param session: an active driver Session (direct, not through NLB) + :param connection_id: the connection_id string to filter on + :param expected_count: how many rows we expect to see per node + :param timeout: maximum seconds to wait + :param poll_interval: seconds between polls + """ + all_hosts = list(session.cluster.metadata.all_hosts()) + deadline = time.time() + timeout + while True: + pending_hosts = [] + for host in all_hosts: + rows = list(session.execute( + "SELECT * FROM system.client_routes WHERE connection_id = %s", + (connection_id,), + host=host, + )) + if len(rows) < expected_count: + pending_hosts.append((host, len(rows))) + if not pending_hosts: + return + if time.time() >= deadline: + details = ", ".join( + "%s: %d" % (h.address, count) for h, count in pending_hosts + ) + raise RuntimeError( + "Timed out waiting for %d routes (connection_id=%s) to appear " + "in system.client_routes on all nodes; pending: %s" + % (expected_count, connection_id, details) + ) + time.sleep(poll_interval) + + +def node_id_from_ip(ip): + """Extract node_id from an IP like '127.0.0.3' -> 3.""" + return int(ip.split(".")[-1]) + + +def assert_routes_via_nlb(test, cluster, nlb, expected_node_ids): + """ + Assert that every host in *expected_node_ids* has its endpoint + resolving through the NLB (correct address and per-node port). + """ + nlb_listen_host = NLBEmulator.LISTEN_HOST + expected_node_ids = set(expected_node_ids) + + seen_node_ids = set() + for host in cluster.metadata.all_hosts(): + ep = host.endpoint + if not isinstance(ep, ClientRoutesEndPoint): + continue + node_id = node_id_from_ip(ep.address) + if node_id not in expected_node_ids: + continue + resolved_addr, resolved_port = ep.resolve() + test.assertEqual( + resolved_addr, nlb_listen_host, + "Node %d endpoint should resolve to NLB address %s, got %s" + % (node_id, nlb_listen_host, resolved_addr), + ) + test.assertEqual( + resolved_port, nlb.node_port(node_id), + "Node %d endpoint should resolve to NLB port %d, got %d" + % (node_id, nlb.node_port(node_id), resolved_port), + ) + seen_node_ids.add(node_id) + test.assertEqual( + seen_node_ids, expected_node_ids, + "Not all expected nodes found in metadata endpoints", + ) + + +def assert_routes_direct(test, cluster, expected_node_ids, direct_port=9042): + """ + Assert that every host in *expected_node_ids* has its endpoint + resolving to the node's own IP on *direct_port*. + """ + expected_node_ids = set(expected_node_ids) + + for host in cluster.metadata.all_hosts(): + ep = host.endpoint + if not isinstance(ep, ClientRoutesEndPoint): + continue + node_id = node_id_from_ip(ep.address) + if node_id not in expected_node_ids: + continue + resolved_addr, resolved_port = ep.resolve() + expected_ip = "127.0.0.%d" % node_id + test.assertEqual( + resolved_addr, expected_ip, + "Node %d endpoint should resolve to direct address %s, got %s" + % (node_id, expected_ip, resolved_addr), + ) + test.assertEqual( + resolved_port, direct_port, + "Node %d endpoint should resolve to direct port %d, got %d" + % (node_id, direct_port, resolved_port), + ) + + +def setup_module(): + os.environ['SCYLLA_EXT_OPTS'] = "--smp 2 --memory 2048M" + use_cluster('test_client_routes', [3], start=True) + +@skip_scylla_version_lt(reason='scylladb/scylladb#26992 - system.client_routes is not yet supported', + scylla_version="2026.1.0") +class TestGetHostPortMapping(unittest.TestCase): + """ + Test _query_all_routes_for_connections and _query_routes_for_change_event + methods with different filtering scenarios. + """ + + @classmethod + def setUpClass(cls): + cls.cluster = TestCluster(client_routes_config=ClientRoutesConfig( + proxies=[ClientRouteProxy("conn_id", "127.0.0.1")])) + cls.session = cls.cluster.connect() + + cls.host_ids = [uuid.uuid4() for _ in range(3)] + cls.connection_ids = [str(uuid.uuid4()) for _ in range(3)] + cls.expected = [] + + for idx, host_id in enumerate(cls.host_ids): + ip = f"127.0.0.{idx + 1}" + for connection_id in cls.connection_ids: + cls.expected.append({ + 'connection_id': connection_id, + 'host_id': host_id, + 'address': ip, + 'port': 9042, + 'tls_port': 9142, + }) + + cls._sort_routes(cls.expected) + post_client_routes(cls.cluster.contact_points[0], cls.expected) + + @classmethod + def tearDownClass(cls): + cls.cluster.shutdown() + + @staticmethod + def _sort_routes(routes): + routes.sort(key=lambda r: (str(r['connection_id']), str(r['host_id']))) + + def _routes_to_dicts(self, routes): + """Convert _Route objects to comparable dicts, adjusting port for ssl_enabled.""" + return [ + { + 'connection_id': route.connection_id, + 'host_id': route.host_id, + 'address': route.address, + 'port': route.port, + } + for route in routes + ] + + def _expected_dicts(self, expected): + """Build expected dicts with tls_port or port based on ssl_enabled.""" + port_key = 'tls_port' if self.cluster._client_routes_handler.ssl_enabled else 'port' + return [ + { + 'connection_id': e['connection_id'], + 'host_id': e['host_id'], + 'address': e['address'], + 'port': e[port_key], + } + for e in expected + ] + + def test_get_all_routes_for_all_connections(self): + """Querying all connection IDs returns every route.""" + cc = self.cluster.control_connection + routes = self.cluster._client_routes_handler._query_all_routes_for_connections( + cc._connection, cc._timeout, self.connection_ids, + ) + got = self._routes_to_dicts(routes) + self._sort_routes(got) + expected = self._expected_dicts(self.expected) + self._sort_routes(expected) + self.assertEqual(got, expected) + + def test_get_routes_for_single_connection(self): + """Querying a single connection ID returns only its routes.""" + cc = self.cluster.control_connection + routes = self.cluster._client_routes_handler._query_all_routes_for_connections( + cc._connection, cc._timeout, [self.connection_ids[0]], + ) + got = self._routes_to_dicts(routes) + self._sort_routes(got) + filtered = [r for r in self.expected + if r['connection_id'] == self.connection_ids[0]] + expected = self._expected_dicts(filtered) + self._sort_routes(expected) + self.assertEqual(got, expected) + + def test_get_routes_for_change_event_all_pairs(self): + """Querying all (connection_id, host_id) pairs returns every route.""" + cc = self.cluster.control_connection + pairs = [(r['connection_id'], r['host_id']) for r in self.expected] + routes = self.cluster._client_routes_handler._query_routes_for_change_event( + cc._connection, cc._timeout, pairs, + ) + got = self._routes_to_dicts(routes) + self._sort_routes(got) + expected = self._expected_dicts(self.expected) + self._sort_routes(expected) + self.assertEqual(got, expected) + + def test_get_routes_for_change_event_single_pair(self): + """Querying a single (connection_id, host_id) pair returns one route.""" + cc = self.cluster.control_connection + target_conn_id = self.connection_ids[0] + target_host_id = self.host_ids[0] + routes = self.cluster._client_routes_handler._query_routes_for_change_event( + cc._connection, cc._timeout, [(target_conn_id, target_host_id)], + ) + got = self._routes_to_dicts(routes) + self._sort_routes(got) + filtered = [r for r in self.expected + if r['connection_id'] == target_conn_id + and r['host_id'] == target_host_id] + expected = self._expected_dicts(filtered) + self._sort_routes(expected) + self.assertEqual(got, expected) + +@skip_scylla_version_lt(reason='scylladb/scylladb#26992 - system.client_routes is not yet supported', + scylla_version="2026.1.0") +class TestPrivateLinkConnectivity(unittest.TestCase): + """ + Verifies the driver connects to all cluster nodes exclusively through + the NLB proxy, never directly. + + Setup: + 1. Start a 3-node CCM cluster (done by setup_module). + 2. Start an NLB emulator with per-node proxies. + 3. Use a direct session to read host_ids, then POST client routes + pointing each host_id at the NLB proxy port. + 4. Create a client-routes-enabled session using the NLB discovery + port as the contact point. + 5. Verify all driver connections go through proxy ports. + """ + + @classmethod + def setUpClass(cls): + cls.direct_cluster = TestCluster() + cls.direct_session = cls.direct_cluster.connect() + cls.host_id_map = get_host_ids_from_cluster(cls.direct_session) + log.info("Host ID map: %s", cls.host_id_map) + + cls.node_addrs = {} + for ip in cls.host_id_map: + node_id = int(ip.split(".")[-1]) + cls.node_addrs[node_id] = ip + + cls.nlb = NLBEmulator() + cls.nlb.start(cls.node_addrs) + + cls.connection_id = str(uuid.uuid4()) + post_routes_for_nlb("127.0.0.1", cls.connection_id, cls.host_id_map, cls.nlb) + wait_for_routes_visible(cls.direct_session, cls.connection_id, len(cls.host_id_map)) + + @classmethod + def tearDownClass(cls): + cls.direct_cluster.shutdown() + cls.nlb.stop() + + def _make_client_routes_cluster(self, **extra_kwargs): + """Create a Cluster configured with client-routes pointing at the NLB.""" + return Cluster( + contact_points=[NLBEmulator.LISTEN_HOST], + port=self.nlb.discovery_port, + client_routes_config=ClientRoutesConfig( + proxies=[ClientRouteProxy(self.connection_id, NLBEmulator.LISTEN_HOST)], + ), + load_balancing_policy=RoundRobinPolicy(), + **extra_kwargs, + ) + + def test_all_connections_through_proxy(self): + """Every pool connection must go through the NLB proxy, not directly.""" + with self._make_client_routes_cluster() as cluster: + session = cluster.connect(wait_for_all_pools=True) + + for _ in range(50): + session.execute("SELECT key FROM system.local") + + pool_state = session.get_pool_state() + self.assertEqual(len(pool_state), len(self.node_addrs), + "Driver should have pools for all nodes") + + for host, state in pool_state.items(): + node_id = node_id_from_ip(host.address) + proxy = self.nlb.get_node_proxy(node_id) + self.assertIsNotNone(proxy, f"No proxy for node {node_id}") + open_count = state['open_count'] + self.assertGreaterEqual( + proxy.total_connections, open_count, + f"Node {node_id} proxy saw {proxy.total_connections} " + f"connections but pool has {open_count} open — " + f"some connections bypassed the proxy") + + assert_routes_via_nlb(self, cluster, self.nlb, + self.node_addrs.keys()) + + def test_queries_succeed_through_proxy(self): + """Queries should work normally through the proxy.""" + with self._make_client_routes_cluster() as cluster: + session = cluster.connect() + session.execute( + "CREATE KEYSPACE IF NOT EXISTS test_cr_ks " + "WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 3}" + ) + session.execute( + "CREATE TABLE IF NOT EXISTS test_cr_ks.t (k int PRIMARY KEY, v text)" + ) + session.execute("INSERT INTO test_cr_ks.t (k, v) VALUES (1, 'hello')") + row = session.execute("SELECT v FROM test_cr_ks.t WHERE k = 1").one() + self.assertEqual(row.v, "hello") + + assert_routes_via_nlb(self, cluster, self.nlb, + self.node_addrs.keys()) + + def test_connection_recovery_after_proxy_drop(self): + """ + After the proxy drops all connections, the driver should reconnect + (still through the proxy). + """ + with self._make_client_routes_cluster() as cluster: + session = cluster.connect(wait_for_all_pools=True) + session.execute("SELECT key FROM system.local") + + assert_routes_via_nlb(self, cluster, self.nlb, + self.node_addrs.keys()) + + self.nlb.drop_all_connections() + + def query_ok(): + session.execute("SELECT key FROM system.local") + + wait_until_not_raised(query_ok, 1, 30) + + assert_routes_via_nlb(self, cluster, self.nlb, + self.node_addrs.keys()) + +@skip_scylla_version_lt(reason='scylladb/scylladb#26992 - system.client_routes is not yet supported', + scylla_version="2026.1.0") +class TestDynamicRouteUpdates(unittest.TestCase): + """ + Verify that when routes are updated (e.g. port changes), the driver + picks up the new routes and reconnects through the new proxy ports + after existing connections are dropped. + """ + + @classmethod + def setUpClass(cls): + cls.direct_cluster = TestCluster() + cls.direct_session = cls.direct_cluster.connect() + cls.host_id_map = get_host_ids_from_cluster(cls.direct_session) + + cls.node_addrs = {} + for ip in cls.host_id_map: + node_id = int(ip.split(".")[-1]) + cls.node_addrs[node_id] = ip + + cls.connection_id = str(uuid.uuid4()) + + @classmethod + def tearDownClass(cls): + cls.direct_cluster.shutdown() + + def test_route_update_causes_reconnect_to_new_port(self): + """ + 1. Start NLB v1, post routes -> driver connects through v1 ports. + 2. Start NLB v2 on different ports, post new routes. + 3. Drop v1 connections. + 4. Driver should reconnect through v2 ports. + """ + with NLBEmulator( + node_addresses=self.node_addrs, + ) as nlb_v1, NLBEmulator( + node_addresses=self.node_addrs, + ) as nlb_v2: + post_routes_for_nlb("127.0.0.1", self.connection_id, + self.host_id_map, nlb_v1) + wait_for_routes_visible(self.direct_session, self.connection_id, len(self.host_id_map)) + + with Cluster( + contact_points=[NLBEmulator.LISTEN_HOST], + port=nlb_v1.discovery_port, + client_routes_config=ClientRoutesConfig( + proxies=[ClientRouteProxy(self.connection_id, NLBEmulator.LISTEN_HOST)], + ), + load_balancing_policy=RoundRobinPolicy(), + ) as cluster: + session = cluster.connect(wait_for_all_pools=True) + session.execute("SELECT key FROM system.local") + + for node_id in self.node_addrs: + self.assertGreater( + nlb_v1.get_node_proxy(node_id).total_connections, 0) + assert_routes_via_nlb(self, cluster, nlb_v1, + self.node_addrs.keys()) + + post_routes_for_nlb("127.0.0.1", self.connection_id, + self.host_id_map, nlb_v2) + time.sleep(2) # let CLIENT_ROUTES_CHANGE propagate + + # Stop v1 per-node proxies entirely so v1 ports become + # unreachable, forcing the driver to reconnect through v2. + # (Merely dropping connections is insufficient because v1 + # proxies would still accept new connections before the + # route update propagates.) + for node_id in list(self.node_addrs.keys()): + nlb_v1.remove_node(node_id) + + def all_nodes_via_v2(): + session.execute("SELECT key FROM system.local") + for nid in self.node_addrs: + assert nlb_v2.get_node_proxy(nid).total_connections > 0, \ + "NLB v2 node %d proxy has no connections yet" % nid + + wait_until_not_raised(all_nodes_via_v2, 1, 30) + + assert_routes_via_nlb(self, cluster, nlb_v2, + self.node_addrs.keys()) + + +def _generate_ssl_certs(cert_dir, node_ips): + """ + Generate test SSL certificates with SANs covering the given node IPs. + + File names follow CCM's ``ScyllaCluster.enable_ssl()`` convention so the + resulting directory can be passed directly to ``enable_ssl(cert_dir, ...)``. + + Creates: + - ca.key / ca.crt: self-signed CA + - ccm_node.key / ccm_node.pem: server cert signed by CA with SANs for all node_ips + + :param cert_dir: directory to write files into (must exist) + :param node_ips: list of IP strings to include as SANs (e.g. ["127.0.0.1", "127.0.0.2"]) + """ + if shutil.which("openssl") is None: + raise unittest.SkipTest("openssl not found on PATH; skipping SSL cert generation") + + san_cnf = os.path.join(cert_dir, "san.cnf") + san_value = ",".join("IP:%s" % ip for ip in node_ips) + with open(san_cnf, "w") as f: + f.write("subjectAltName=%s\n" % san_value) + + def _run(cmd): + result = subprocess.run(cmd, cwd=cert_dir, capture_output=True, text=True) + if result.returncode != 0: + raise RuntimeError("Command failed: %s\n%s" % (" ".join(cmd), result.stderr)) + + _run(["openssl", "req", "-x509", "-newkey", "rsa:2048", + "-keyout", "ca.key", "-out", "ca.crt", + "-days", "1", "-nodes", "-subj", "/CN=Test CA"]) + + _run(["openssl", "req", "-newkey", "rsa:2048", + "-keyout", "ccm_node.key", "-out", "ccm_node.csr", + "-nodes", "-subj", "/CN=Test Server"]) + + _run(["openssl", "x509", "-req", + "-in", "ccm_node.csr", "-CA", "ca.crt", "-CAkey", "ca.key", + "-CAcreateserial", "-out", "ccm_node.pem", + "-days", "1", "-extfile", "san.cnf"]) + + log.info("Generated SSL certs in %s with SANs: %s", cert_dir, san_value) + + +@skip_scylla_version_lt(reason='scylladb/scylladb#26992 - system.client_routes is not yet supported', + scylla_version="2026.1.0") +class TestMixedDirectAndNlbConnections(unittest.TestCase): + """ + Verify the cluster works when some nodes are accessed through the NLB + proxy and others are accessed directly (no route posted, falls back + to the default endpoint). + """ + + @classmethod + def setUpClass(cls): + cls.direct_cluster = TestCluster() + cls.direct_session = cls.direct_cluster.connect() + cls.host_id_map = get_host_ids_from_cluster(cls.direct_session) + + cls.node_addrs = {} + for ip in cls.host_id_map: + node_id = int(ip.split(".")[-1]) + cls.node_addrs[node_id] = ip + + cls.connection_id = str(uuid.uuid4()) + + @classmethod + def tearDownClass(cls): + cls.direct_cluster.shutdown() + + def test_mixed_direct_and_nlb_connections(self): + """ + Post routes for only a subset of nodes (through NLB proxy). + Remaining nodes have no route and fall back to direct connections. + Queries should work through both paths. + """ + proxied_node_id = min(self.node_addrs.keys()) + proxied_ip = self.node_addrs[proxied_node_id] + + with NLBEmulator( + node_addresses={proxied_node_id: proxied_ip}, + ) as nlb: + proxied_host_id = self.host_id_map[proxied_ip] + routes = [{ + "connection_id": self.connection_id, + "host_id": proxied_host_id, + "address": NLBEmulator.LISTEN_HOST, + "port": nlb.node_port(proxied_node_id), + }] + post_client_routes("127.0.0.1", routes) + time.sleep(1) + + with Cluster( + contact_points=["127.0.0.1"], + client_routes_config=ClientRoutesConfig( + proxies=[ClientRouteProxy(self.connection_id, NLBEmulator.LISTEN_HOST)], + ), + load_balancing_policy=RoundRobinPolicy(), + ) as cluster: + session = cluster.connect(wait_for_all_pools=True) + + for _ in range(50): + session.execute("SELECT key FROM system.local") + + assert_routes_via_nlb(self, cluster, nlb, + [proxied_node_id]) + + direct_node_ids = set(self.node_addrs.keys()) - {proxied_node_id} + assert_routes_direct(self, cluster, direct_node_ids) + + proxy = nlb.get_node_proxy(proxied_node_id) + self.assertGreater(proxy.total_connections, 0, + "Proxied node should have connections through NLB") + + +@skip_scylla_version_lt(reason='scylladb/scylladb#26992 - system.client_routes is not yet supported', + scylla_version="2026.1.0") +class TestSslThroughNlb(unittest.TestCase): + """ + Verify SSL with check_hostname=False works through the NLB proxy. + + When using client routes, connections go through NLB proxies whose + addresses won't match server certificates, so hostname verification + must be disabled. Certificate chain validation (verify_mode=CERT_REQUIRED) + is still active — only hostname matching is skipped. + + The driver raises ValueError at Cluster init time if check_hostname=True + is used with client_routes_config. + """ + + @classmethod + def setUpClass(cls): + cls.direct_cluster = TestCluster() + cls.direct_session = cls.direct_cluster.connect() + cls.host_id_map = get_host_ids_from_cluster(cls.direct_session) + cls.direct_cluster.shutdown() + + cls.node_addrs = {} + for ip in cls.host_id_map: + node_id = int(ip.split(".")[-1]) + cls.node_addrs[node_id] = ip + + cls.connection_id = str(uuid.uuid4()) + + cls.cert_dir = tempfile.mkdtemp(prefix="client-routes-ssl-") + cert_ips = list(cls.node_addrs.values()) + _generate_ssl_certs(cls.cert_dir, cert_ips) + + cls.ccm_cluster = get_cluster() + cls.ccm_cluster.stop() + cls.ccm_cluster.set_configuration_options({ + 'client_encryption_options': { + 'enabled': True, + 'certificate': os.path.join(cls.cert_dir, "ccm_node.pem"), + 'keyfile': os.path.join(cls.cert_dir, "ccm_node.key"), + } + }) + cls.ccm_cluster.start(wait_for_binary_proto=True) + + @classmethod + def tearDownClass(cls): + cls.ccm_cluster.stop() + cls.ccm_cluster.set_configuration_options({ + 'client_encryption_options': { + 'enabled': False, + } + }) + cls.ccm_cluster.start(wait_for_binary_proto=True) + + shutil.rmtree(cls.cert_dir, ignore_errors=True) + + def test_ssl_without_hostname_verification_through_nlb(self): + """ + Connect through NLB with SSL but check_hostname=False. + + When using client routes, connections go through NLB proxies + whose addresses won't match server certificates, so hostname + verification must be disabled. Certificate chain validation + (verify_mode=CERT_REQUIRED) is still active. + """ + with NLBEmulator( + node_addresses=self.node_addrs, + ) as nlb: + routes = build_routes_for_nlb( + self.connection_id, self.host_id_map, nlb, + ) + for route in routes: + route["tls_port"] = route["port"] + post_client_routes("127.0.0.1", routes) + + ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ssl_ctx.check_hostname = False + ssl_ctx.load_verify_locations(os.path.join(self.cert_dir, 'ca.crt')) + + self.assertFalse(ssl_ctx.check_hostname, + "check_hostname must be False for this test") + self.assertEqual(ssl_ctx.verify_mode, ssl.CERT_REQUIRED, + "verify_mode must be CERT_REQUIRED") + + def routes_visible(): + with TestCluster( + contact_points=["127.0.0.1"], + ssl_context=ssl_ctx, + ) as c: + session = c.connect() + rs = session.execute( + "SELECT * FROM system.client_routes " + "WHERE connection_id = %s ALLOW FILTERING", + (self.connection_id,) + ) + return len(list(rs)) >= len(self.host_id_map) + + wait_until_not_raised( + lambda: self.assertTrue(routes_visible()), + 0.5, 10, + ) + + with Cluster( + contact_points=[NLBEmulator.LISTEN_HOST], + port=nlb.discovery_port, + ssl_context=ssl_ctx, + client_routes_config=ClientRoutesConfig( + proxies=[ClientRouteProxy(self.connection_id, NLBEmulator.LISTEN_HOST)], + ), + load_balancing_policy=RoundRobinPolicy(), + ) as cluster: + session = cluster.connect(wait_for_all_pools=True) + + for _ in range(20): + row = session.execute( + "SELECT release_version FROM system.local" + ).one() + self.assertIsNotNone(row) + + assert_routes_via_nlb(self, cluster, nlb, + self.node_addrs.keys()) + + def test_ssl_with_hostname_verification_raises_error(self): + """ + Verify that Cluster raises ValueError when client_routes_config + is used with SSL hostname verification enabled. + """ + ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ssl_ctx.load_verify_locations(os.path.join(self.cert_dir, 'ca.crt')) + self.assertTrue(ssl_ctx.check_hostname) + + with self.assertRaises(ValueError) as cm: + Cluster( + contact_points=[NLBEmulator.LISTEN_HOST], + ssl_context=ssl_ctx, + client_routes_config=ClientRoutesConfig( + proxies=[ClientRouteProxy("test-id", NLBEmulator.LISTEN_HOST)], + ), + ) + self.assertIn("check_hostname", str(cm.exception)) + +@skip_scylla_version_lt(reason='scylladb/scylladb#26992 - system.client_routes is not yet supported', + scylla_version="2026.1.0") +class TestFullNodeReplacementThroughNlb(unittest.TestCase): + """ + End-to-end test: creates a session through an NLB proxy with client routes, + scales the cluster up, then decommissions original nodes, verifying the + session survives the full node replacement. + + This test is destructive — it modifies the CCM cluster topology by + bootstrapping new nodes and decommissioning original ones. It uses + its own CCM cluster so it cannot interfere with other tests. + """ + + @classmethod + def setUpClass(cls): + os.environ['SCYLLA_EXT_OPTS'] = "--smp 2 --memory 2048M" + use_cluster('test_client_routes_replacement', [3], start=True) + + cls.direct_cluster = TestCluster() + cls.direct_session = cls.direct_cluster.connect() + cls.host_id_map = get_host_ids_from_cluster(cls.direct_session) + + cls.node_addrs = {} + for ip in cls.host_id_map: + node_id = int(ip.split(".")[-1]) + cls.node_addrs[node_id] = ip + + cls.connection_id = str(uuid.uuid4()) + + @classmethod + def tearDownClass(cls): + cls.direct_cluster.shutdown() + + def test_should_survive_full_node_replacement_through_nlb(self): + """ + 1. Start with 3 nodes behind the NLB + 2. Bootstrap 2 new nodes, add to NLB, update routes + 3. Decommission the original 3 nodes one-by-one, updating NLB/routes + 4. Verify the session survives with only new nodes + """ + original_node_ids = sorted(self.node_addrs.keys()) + with NLBEmulator( + node_addresses=self.node_addrs, + ) as nlb: + # ---- Stage 1: Set up NLB for initial nodes ---- + log.info("Stage 1: Setting up NLB for %d initial nodes", len(original_node_ids)) + + post_routes_for_nlb("127.0.0.1", self.connection_id, self.host_id_map, nlb) + wait_for_routes_visible(self.direct_session, self.connection_id, len(self.host_id_map)) + + # ---- Stage 2: Create session through NLB ---- + log.info("Stage 2: Creating session through NLB") + with Cluster( + contact_points=[NLBEmulator.LISTEN_HOST], + port=nlb.discovery_port, + client_routes_config=ClientRoutesConfig( + proxies=[ClientRouteProxy(self.connection_id, NLBEmulator.LISTEN_HOST)], + ), + load_balancing_policy=RoundRobinPolicy(), + ) as cluster: + session = cluster.connect(wait_for_all_pools=True) + self._assert_query_works(session) + + handler = cluster._client_routes_handler + self.assertIsNotNone(handler) + + assert_routes_via_nlb(self, cluster, nlb, + original_node_ids) + log.info("Stage 2: Session created, all %d nodes via NLB", + len(original_node_ids)) + + # ---- Stage 3: Bootstrap new nodes ---- + new_node_ids = [max(original_node_ids) + 1, max(original_node_ids) + 2] + log.info("Stage 3: Adding nodes %s", new_node_ids) + ccm_cluster = get_cluster() + + for node_id in new_node_ids: + self._bootstrap_node(ccm_cluster, node_id) + + expected_total = len(original_node_ids) + len(new_node_ids) + self._wait_for_condition( + lambda: len(cluster.metadata.all_hosts()) >= expected_total, + timeout_seconds=60, + description="%d nodes in metadata" % expected_total, + ) + + for node_id in new_node_ids: + nlb.add_node(node_id, "127.0.0.%d" % node_id) + + all_host_ids = get_host_ids_from_cluster(session) + log.info("All host IDs after expansion: %s", all_host_ids) + post_routes_for_nlb("127.0.0.1", self.connection_id, all_host_ids, nlb) + + handler.initialize( + cluster.control_connection._connection, + cluster.control_connection._timeout) + + self._wait_for_condition( + lambda: sum(1 for h in cluster.metadata.all_hosts() if h.is_up) >= expected_total, + timeout_seconds=60, + description="all %d nodes up" % expected_total, + ) + + self._assert_query_works(session) + + all_node_ids = set(original_node_ids) | set(new_node_ids) + assert_routes_via_nlb(self, cluster, nlb, all_node_ids) + log.info("Stage 3: All %d nodes via NLB after expansion", + len(all_node_ids)) + + # ---- Stage 4: Decommission original nodes ---- + log.info("Stage 4: Decommissioning original nodes %s", original_node_ids) + + remaining_node_ids = set(all_node_ids) + remaining_host_ids = dict(all_host_ids) + for node_id in original_node_ids: + log.info("Decommissioning node %d", node_id) + get_node(node_id).decommission() + nlb.remove_node(node_id) + remaining_node_ids.discard(node_id) + + ip = "127.0.0.%d" % node_id + remaining_host_ids.pop(ip, None) + + surviving_ips = list(remaining_host_ids.keys()) + if surviving_ips: + post_routes_for_nlb( + surviving_ips[0], self.connection_id, + remaining_host_ids, nlb, + ) + + expected_remaining = expected_total - (original_node_ids.index(node_id) + 1) + self._wait_for_condition( + lambda er=expected_remaining: ( + len(cluster.metadata.all_hosts()) <= er + and self._query_succeeds(session) + ), + timeout_seconds=60, + description="node %d decommissioned" % node_id, + ) + + # Reload routes after the control connection has + # re-established itself (the decommission may have + # killed the old control connection). + handler.initialize( + cluster.control_connection._connection, + cluster.control_connection._timeout) + + assert_routes_via_nlb(self, cluster, nlb, + remaining_node_ids) + log.info("Node %d decommissioned, %d nodes still via NLB", + node_id, len(remaining_node_ids)) + + # ---- Stage 5: Verify with only new nodes ---- + log.info("Stage 5: Verifying session works with only new nodes %s", new_node_ids) + self._assert_query_works(session) + + hosts = cluster.metadata.all_hosts() + self.assertEqual( + len(hosts), len(new_node_ids), + "Expected %d hosts, got %d" % (len(new_node_ids), len(hosts)) + ) + + for _ in range(10): + self._assert_query_works(session) + + assert_routes_via_nlb(self, cluster, nlb, new_node_ids) + log.info("PASS: Full node replacement, all %d new nodes via NLB", + len(new_node_ids)) + + def _assert_query_works(self, session): + rs = session.execute("SELECT release_version FROM system.local WHERE key='local'") + row = rs.one() + self.assertIsNotNone(row, "Query via NLB should return a result") + + def _query_succeeds(self, session): + try: + self._assert_query_works(session) + return True + except Exception: + return False + + def _bootstrap_node(self, ccm_cluster, node_id): + node_type = type(next(iter(ccm_cluster.nodes.values()))) + ip = "127.0.0.%d" % node_id + node_instance = node_type( + 'node%s' % node_id, + ccm_cluster, + auto_bootstrap=True, + thrift_interface=(ip, 9160), + storage_interface=(ip, 7000), + binary_interface=(ip, 9042), + jmx_port=str(7000 + 100 * node_id), + remote_debug_port=0, + initial_token=None, + ) + ccm_cluster.add(node_instance, is_seed=False) + node_instance.start(wait_for_binary_proto=True, wait_other_notice=True) + wait_for_node_socket(node_instance, 120) + log.info("Node %d bootstrapped successfully", node_id) + + @staticmethod + def _wait_for_condition(predicate, timeout_seconds, poll_interval=2, description="condition"): + deadline = time.time() + timeout_seconds + while time.time() < deadline: + if predicate(): + return True + time.sleep(poll_interval) + raise AssertionError( + "Timed out waiting for %s after %d seconds" % (description, timeout_seconds) + ) diff --git a/tests/unit/cython/test_parse_desc_cache.py b/tests/unit/cython/test_parse_desc_cache.py new file mode 100644 index 0000000000..256d98656e --- /dev/null +++ b/tests/unit/cython/test_parse_desc_cache.py @@ -0,0 +1,207 @@ +# Copyright ScyllaDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unit tests for the ParseDesc cache in row_parser.pyx. + +Validates cache hit/miss behavior, protocol_version invalidation, cache +clearing, and bounded eviction — all exercised through the actual Cython +_get_or_build_parse_desc function via make_recv_results_rows(). +""" + +import io +import struct +import unittest + +from tests.unit.cython.utils import cythontest + +try: + from cassandra.row_parser import ( + clear_parse_desc_cache, + get_parse_desc_cache_size, + make_recv_results_rows, + ) + from cassandra.obj_parser import ListParser + + _HAS_ROW_PARSER = True + _recv_results_rows = make_recv_results_rows(ListParser()) +except ImportError: + _HAS_ROW_PARSER = False + _recv_results_rows = None + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _build_column_metadata(ncols): + """Build a column_metadata list like the driver produces.""" + from cassandra import cqltypes + + return [("ks", "tbl", "col_%d" % i, cqltypes.UTF8Type) for i in range(ncols)] + + +# NO_METADATA_FLAG as defined in ResultMessage +_NO_METADATA_FLAG = 0x0004 + + +class _MockResultMessage: + """Minimal mock of ResultMessage for the prepared-statement path.""" + + column_metadata = None + column_names = None + column_types = None + parsed_rows = None + paging_state = None + continuous_paging_seq = None + continuous_paging_last = None + result_metadata_id = None + + def recv_results_metadata(self, f, user_type_map): + """Simulate the prepared-statement path (NO_METADATA_FLAG is set).""" + _flags = struct.unpack(">i", f.read(4))[0] + _colcount = struct.unpack(">i", f.read(4))[0] + + +def _build_binary_buf(nrows, ncols, col_value=b"hello world"): + """Build a full binary buffer for the prepared-statement path.""" + parts = [] + parts.append(struct.pack(">i", _NO_METADATA_FLAG)) + parts.append(struct.pack(">i", ncols)) + parts.append(struct.pack(">i", nrows)) + col_cell = struct.pack(">i", len(col_value)) + col_value + row_data = col_cell * ncols + for _ in range(nrows): + parts.append(row_data) + return b"".join(parts) + + +def _recv(binary_buf, col_meta, protocol_version=4, ce_policy=None): + """Run recv_results_rows and return the MockResultMessage.""" + msg = _MockResultMessage() + _recv_results_rows( + msg, io.BytesIO(binary_buf), protocol_version, {}, col_meta, ce_policy + ) + return msg + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class ParseDescCacheTest(unittest.TestCase): + """Tests for the Cython ParseDesc cache in row_parser.pyx.""" + + def setUp(self): + if _HAS_ROW_PARSER: + clear_parse_desc_cache() + + def tearDown(self): + if _HAS_ROW_PARSER: + clear_parse_desc_cache() + + @cythontest + def test_cache_hit_returns_same_objects(self): + """Repeated calls with the same col_meta object should return + identical column_names and column_types objects (cache hit).""" + col_meta = _build_column_metadata(5) + buf = _build_binary_buf(1, 5) + + msg1 = _recv(buf, col_meta) + msg2 = _recv(buf, col_meta) + + self.assertIs(msg1.column_names, msg2.column_names) + self.assertIs(msg1.column_types, msg2.column_types) + + @cythontest + def test_cache_miss_different_metadata(self): + """Different metadata list objects should produce cache misses.""" + buf = _build_binary_buf(1, 5) + col_meta_a = _build_column_metadata(5) + col_meta_b = _build_column_metadata(5) + + msg_a = _recv(buf, col_meta_a) + msg_b = _recv(buf, col_meta_b) + + self.assertIsNot(msg_a.column_names, msg_b.column_names) + self.assertEqual(msg_a.column_names, msg_b.column_names) + + @cythontest + def test_protocol_version_invalidates_cache(self): + """Changed protocol_version should invalidate the cache entry.""" + col_meta = _build_column_metadata(5) + buf = _build_binary_buf(1, 5) + + msg_v4 = _recv(buf, col_meta, protocol_version=4) + msg_v5 = _recv(buf, col_meta, protocol_version=5) + + self.assertIsNot(msg_v4.column_names, msg_v5.column_names) + + @cythontest + def test_clear_cache_invalidates_entries(self): + """clear_parse_desc_cache() should invalidate cached entries.""" + col_meta = _build_column_metadata(5) + buf = _build_binary_buf(1, 5) + + msg1 = _recv(buf, col_meta) + clear_parse_desc_cache() + msg2 = _recv(buf, col_meta) + + self.assertIsNot(msg1.column_names, msg2.column_names) + self.assertEqual(msg1.column_names, msg2.column_names) + + @cythontest + def test_cache_bounded_size(self): + """Cache should evict entries when exceeding the max size (256).""" + buf = _build_binary_buf(1, 5) + meta_lists = [_build_column_metadata(5) for _ in range(300)] + + for meta in meta_lists: + _recv(buf, meta) + + cache_size = get_parse_desc_cache_size() + self.assertLessEqual( + cache_size, + 256, + "Cache should be bounded to 256 entries, got %d" % cache_size, + ) + + @cythontest + def test_parsed_rows_correctness(self): + """Verify parsed row data is correct through the cached path.""" + ncols, nrows = 5, 3 + col_meta = _build_column_metadata(ncols) + buf = _build_binary_buf(nrows, ncols, col_value=b"test_val") + + msg = _recv(buf, col_meta) + + self.assertEqual(len(msg.parsed_rows), nrows) + for row in msg.parsed_rows: + self.assertEqual(len(row), ncols) + for val in row: + self.assertEqual(val, "test_val") + self.assertEqual(msg.column_names, ["col_%d" % i for i in range(ncols)]) + + @cythontest + def test_get_cache_size(self): + """get_parse_desc_cache_size() reports correct count.""" + self.assertEqual(get_parse_desc_cache_size(), 0) + + col_meta = _build_column_metadata(5) + buf = _build_binary_buf(1, 5) + _recv(buf, col_meta) + + self.assertEqual(get_parse_desc_cache_size(), 1) diff --git a/tests/unit/test_client_routes.py b/tests/unit/test_client_routes.py new file mode 100644 index 0000000000..0aa82fc76a --- /dev/null +++ b/tests/unit/test_client_routes.py @@ -0,0 +1,482 @@ +# Copyright 2026 ScyllaDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import socket +import ssl +import unittest +import uuid +from unittest.mock import Mock, patch + +from cassandra.client_routes import ( + ClientRouteProxy, + ClientRoutesChangeType, + ClientRoutesConfig, + _RouteStore, + _Route, + _ClientRoutesHandler +) +from cassandra.connection import ClientRoutesEndPoint, ClientRoutesEndPointFactory +from cassandra.cluster import Cluster + + +class TestClientRouteProxy(unittest.TestCase): + + def test_endpoint_none_connection_id(self): + with self.assertRaises(ValueError): + ClientRouteProxy(None) + + +class TestClientRoutesConfig(unittest.TestCase): + + def test_config_with_proxies(self): + ep1 = ClientRouteProxy(str(uuid.uuid4()), "10.0.0.1") + ep2 = ClientRouteProxy(str(uuid.uuid4()), "10.0.0.2") + config = ClientRoutesConfig([ep1, ep2]) + self.assertEqual(len(config.proxies), 2) + + def test_config_empty_proxies(self): + with self.assertRaises(ValueError): + ClientRoutesConfig([]) + + def test_config_invalid_proxy_type(self): + with self.assertRaises(TypeError): + ClientRoutesConfig(["not-a-proxy"]) + + + +class TestRouteStore(unittest.TestCase): + + def test_get_by_host_id(self): + routes = _RouteStore() + host_id = uuid.uuid4() + route = _Route( + connection_id=str(uuid.uuid4()), + host_id=host_id, + address="example.com", + port=9042, + ) + + routes.update([route]) + + retrieved = routes.get_by_host_id(host_id) + self.assertEqual(retrieved.host_id, host_id) + self.assertEqual(retrieved.address, "example.com") + + def test_merge_routes(self): + routes = _RouteStore() + host_id1 = uuid.uuid4() + host_id2 = uuid.uuid4() + + route1 = _Route( + connection_id=str(uuid.uuid4()), host_id=host_id1, + address="host1.com", port=9042, + ) + + route2 = _Route( + connection_id=str(uuid.uuid4()), host_id=host_id2, + address="host2.com", port=9042, + ) + + routes.update([route1]) + routes.merge([route2], affected_host_ids={host_id2}) + + self.assertIsNotNone(routes.get_by_host_id(host_id1)) + self.assertIsNotNone(routes.get_by_host_id(host_id2)) + + def test_merge_deletes_affected_host_with_no_new_route(self): + """When an affected host_id has no corresponding new route, it should be removed.""" + store = _RouteStore() + host_id1 = uuid.uuid4() + host_id2 = uuid.uuid4() + conn_id = str(uuid.uuid4()) + + store.update([ + _Route(connection_id=conn_id, host_id=host_id1, address="a.com", port=9042), + _Route(connection_id=conn_id, host_id=host_id2, address="b.com", port=9042), + ]) + self.assertIsNotNone(store.get_by_host_id(host_id1)) + self.assertIsNotNone(store.get_by_host_id(host_id2)) + + # Merge with host_id2 affected but no new route for it → deletion + store.merge([], affected_host_ids={host_id2}) + + self.assertIsNotNone(store.get_by_host_id(host_id1)) + self.assertIsNone(store.get_by_host_id(host_id2)) + + def test_select_preferred_routes_keeps_existing_connection_id(self): + """When multiple connection_ids provide routes for the same host_id, + the one already in use should be preferred.""" + store = _RouteStore() + host_id = uuid.uuid4() + conn_a = "conn-a" + conn_b = "conn-b" + + # Populate store with conn_a for host_id + store.update([_Route(connection_id=conn_a, host_id=host_id, address="a.com", port=9042)]) + self.assertEqual(store.get_by_host_id(host_id).connection_id, conn_a) + + # Update with both conn_a and conn_b for the same host_id + store.update([ + _Route(connection_id=conn_b, host_id=host_id, address="b.com", port=9042), + _Route(connection_id=conn_a, host_id=host_id, address="a-new.com", port=9042), + ]) + # conn_a should be preferred since it was already in use + result = store.get_by_host_id(host_id) + self.assertEqual(result.connection_id, conn_a) + self.assertEqual(result.address, "a-new.com") + + def test_select_preferred_routes_falls_back_when_existing_gone(self): + """When the existing connection_id is no longer among candidates, + the first candidate should be selected.""" + store = _RouteStore() + host_id = uuid.uuid4() + + store.update([_Route(connection_id="old-conn", host_id=host_id, address="old.com", port=9042)]) + + # Update only has new connection_ids + store.update([ + _Route(connection_id="new-a", host_id=host_id, address="a.com", port=9042), + _Route(connection_id="new-b", host_id=host_id, address="b.com", port=9042), + ]) + result = store.get_by_host_id(host_id) + self.assertEqual(result.connection_id, "new-a") + + +class TestClientRoutesHandler(unittest.TestCase): + + def setUp(self): + self.conn_id = uuid.uuid4() + self.proxy = ClientRouteProxy(str(self.conn_id), "10.0.0.1") + self.config = ClientRoutesConfig([self.proxy]) + + def test_handler_initialization(self): + handler = _ClientRoutesHandler(self.config, ssl_enabled=False) + self.assertIsNotNone(handler) + self.assertEqual(handler.ssl_enabled, False) + + @patch.object(_ClientRoutesHandler, '_query_all_routes_for_connections') + def test_initialize(self, mock_query): + host_id = uuid.uuid4() + mock_query.return_value = [ + _Route( + connection_id=self.conn_id, + host_id=host_id, + address="node1.example.com", + port=9042, + ) + ] + + handler = _ClientRoutesHandler(self.config) + mock_conn = Mock() + + handler.initialize(mock_conn, timeout=5.0) + + mock_query.assert_called_once() + route = handler._routes.get_by_host_id(host_id) + self.assertIsNotNone(route) + self.assertEqual(route.address, "node1.example.com") + + @patch.object(_ClientRoutesHandler, '_query_routes_for_change_event') + def test_handle_change_filters_by_configured_connection_ids(self, mock_query): + """Events with unrelated connection_ids should be ignored.""" + handler = _ClientRoutesHandler(self.config) + mock_conn = Mock() + host_id = str(uuid.uuid4()) + + # Event with a connection_id NOT in our config → should return early + handler.handle_client_routes_change( + mock_conn, 5.0, + ClientRoutesChangeType.UPDATE_NODES, + connection_ids=["unrelated-conn-id"], + host_ids=[host_id], + ) + mock_query.assert_not_called() + + @patch.object(_ClientRoutesHandler, '_query_routes_for_change_event') + def test_handle_change_merges_when_host_ids_present(self, mock_query): + """When host_ids are provided, routes should be merged (not full replace).""" + handler = _ClientRoutesHandler(self.config) + mock_conn = Mock() + + existing_host = uuid.uuid4() + new_host = uuid.uuid4() + conn_id = str(self.conn_id) + + # Pre-populate a route + handler._routes.update([ + _Route(connection_id=conn_id, host_id=existing_host, address="old.com", port=9042), + ]) + + mock_query.return_value = [ + _Route(connection_id=conn_id, host_id=new_host, address="new.com", port=9042), + ] + + handler.handle_client_routes_change( + mock_conn, 5.0, + ClientRoutesChangeType.UPDATE_NODES, + connection_ids=[conn_id], + host_ids=[str(new_host)], + ) + + # Existing route should still be there (merge, not replace) + self.assertIsNotNone(handler._routes.get_by_host_id(existing_host)) + self.assertIsNotNone(handler._routes.get_by_host_id(new_host)) + + @patch.object(_ClientRoutesHandler, '_query_all_routes_for_connections') + def test_handle_change_updates_when_no_host_ids(self, mock_query): + """When no host_ids are provided, routes should be fully replaced.""" + handler = _ClientRoutesHandler(self.config) + mock_conn = Mock() + conn_id = str(self.conn_id) + + old_host = uuid.uuid4() + handler._routes.update([ + _Route(connection_id=conn_id, host_id=old_host, address="old.com", port=9042), + ]) + + new_host = uuid.uuid4() + mock_query.return_value = [ + _Route(connection_id=conn_id, host_id=new_host, address="new.com", port=9042), + ] + + handler.handle_client_routes_change( + mock_conn, 5.0, + ClientRoutesChangeType.UPDATE_NODES, + connection_ids=None, + host_ids=None, + ) + + # Full replace: old_host gone, new_host present + self.assertIsNone(handler._routes.get_by_host_id(old_host)) + self.assertIsNotNone(handler._routes.get_by_host_id(new_host)) + + @patch.object(_ClientRoutesHandler, '_query_routes_for_change_event') + def test_handle_change_propagates_query_failure(self, mock_query): + """If _query_routes raises, handle_client_routes_change should propagate.""" + handler = _ClientRoutesHandler(self.config) + mock_conn = Mock() + mock_query.side_effect = Exception("network error") + + conn_id = self.proxy.connection_id + host_id = str(uuid.uuid4()) + with self.assertRaises(Exception) as cm: + handler.handle_client_routes_change( + mock_conn, 5.0, + ClientRoutesChangeType.UPDATE_NODES, + connection_ids=[conn_id], + host_ids=[host_id], + ) + self.assertIn("network error", str(cm.exception)) + + @patch.object(_ClientRoutesHandler, '_query_all_routes_for_connections') + def test_initialize_propagates_exception_on_failure(self, mock_query): + """initialize should propagate exceptions to caller.""" + handler = _ClientRoutesHandler(self.config) + mock_conn = Mock() + mock_query.side_effect = Exception("query failed") + + with self.assertRaises(Exception) as ctx: + handler.initialize(mock_conn, 5.0) + self.assertIn("query failed", str(ctx.exception)) + self.assertEqual(mock_query.call_count, 1) + + @patch.object(_ClientRoutesHandler, '_query_all_routes_for_connections') + def test_initialize_keeps_old_routes_on_failure(self, mock_query): + """On failure, existing routes must be preserved (critical for PL clusters).""" + handler = _ClientRoutesHandler(self.config) + mock_conn = Mock() + host_id = uuid.uuid4() + + # Pre-populate a route + handler._routes.update([ + _Route(connection_id=str(self.conn_id), host_id=host_id, address="old.com", port=9042), + ]) + + mock_query.side_effect = Exception("query failed") + with self.assertRaises(Exception): + handler.initialize(mock_conn, 5.0) + + # Old route must still be there + self.assertIsNotNone(handler._routes.get_by_host_id(host_id)) + + @patch.object(_ClientRoutesHandler, '_query_all_routes_for_connections') + def test_initialize_updates_routes_on_success(self, mock_query): + """initialize should update routes on success.""" + handler = _ClientRoutesHandler(self.config) + mock_conn = Mock() + host_id = uuid.uuid4() + + mock_query.return_value = [ + _Route(connection_id=str(self.conn_id), host_id=host_id, address="new.com", port=9042), + ] + + handler.initialize(mock_conn, 5.0) + + self.assertEqual(mock_query.call_count, 1) + route = handler._routes.get_by_host_id(host_id) + self.assertIsNotNone(route) + self.assertEqual(route.address, "new.com") + +class TestClientRoutesEndPoint(unittest.TestCase): + + def setUp(self): + self.conn_id = uuid.uuid4() + self.proxy = ClientRouteProxy(str(self.conn_id), "10.0.0.1") + self.config = ClientRoutesConfig([self.proxy]) + self.handler = _ClientRoutesHandler(self.config, ssl_enabled=False) + + def test_resolve_falls_back_when_no_mapping(self): + """resolve() should return original address/port when no route mapping exists.""" + host_id = uuid.uuid4() + ep = ClientRoutesEndPoint( + host_id=host_id, + handler=self.handler, + original_address="10.0.0.1", + original_port=9042, + ) + self.assertEqual(ep.resolve(), ("10.0.0.1", 9042)) + + @patch('cassandra.client_routes.socket.getaddrinfo', + return_value=[(socket.AF_INET, socket.SOCK_STREAM, 0, '', ("192.168.1.100", 9042))]) + def test_resolve_returns_address_when_route_exists(self, _mock_getaddrinfo): + """resolve() should return the DNS-resolved address and port when a route exists.""" + host_id = uuid.uuid4() + self.handler._routes.update([ + _Route(connection_id=str(self.conn_id), host_id=host_id, + address="nlb.example.com", port=9042), + ]) + ep = ClientRoutesEndPoint( + host_id=host_id, + handler=self.handler, + original_address="10.0.0.1", + original_port=9042, + ) + self.assertEqual(ep.resolve(), ("192.168.1.100", 9042)) + _mock_getaddrinfo.assert_called_once_with( + "nlb.example.com", 9042, socket.AF_UNSPEC, socket.SOCK_STREAM) + + @patch('cassandra.client_routes.socket.getaddrinfo', + side_effect=socket.gaierror("DNS resolution failed")) + def test_resolve_host_dns_failure_raises(self, _mock_getaddrinfo): + """resolve_host should propagate socket.gaierror on DNS failure.""" + host_id = uuid.uuid4() + self.handler._routes.update([ + _Route(connection_id=str(self.conn_id), host_id=host_id, + address="nonexistent.example.com", port=9042), + ]) + with self.assertRaises(socket.gaierror): + self.handler.resolve_host(host_id) + + def test_resolve_host_missing_port_raises(self): + """resolve_host should raise ValueError when route has no port.""" + host_id = uuid.uuid4() + self.handler._routes.update([ + _Route(connection_id=str(self.conn_id), host_id=host_id, + address="host.com", port=0), + ]) + with self.assertRaises(ValueError): + self.handler.resolve_host(host_id) + + +class TestClientRoutesEndPointFactory(unittest.TestCase): + + def setUp(self): + self.conn_id = uuid.uuid4() + proxy = ClientRouteProxy(str(self.conn_id), "10.0.0.1") + self.config = ClientRoutesConfig([proxy]) + self.handler = _ClientRoutesHandler(self.config, ssl_enabled=False) + self.factory = ClientRoutesEndPointFactory(self.handler, default_port=9042) + + def test_create_from_row(self): + """Factory should create a ClientRoutesEndPoint from a peers row.""" + host_id = uuid.uuid4() + row = { + "host_id": host_id, + "rpc_address": "10.0.0.5", + "native_transport_port": 9042, + "peer": "10.0.0.5", + } + ep = self.factory.create(row) + self.assertIsInstance(ep, ClientRoutesEndPoint) + self.assertEqual(ep.host_id, host_id) + self.assertEqual(ep.address, "10.0.0.5") + + def test_create_missing_host_id_raises(self): + """Factory should raise ValueError when row has no host_id.""" + row = {"rpc_address": "10.0.0.5", "native_transport_port": 9042} + with self.assertRaises(ValueError): + self.factory.create(row) + +class TestClientRoutesSSLValidation(unittest.TestCase): + + def test_check_hostname_with_ssl_context_raises(self): + """Cluster should reject check_hostname=True with client_routes_config.""" + ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + self.assertTrue(ssl_ctx.check_hostname) + + config = ClientRoutesConfig( + proxies=[ClientRouteProxy(str(uuid.uuid4()), "10.0.0.1")] + ) + with self.assertRaises(ValueError) as cm: + Cluster( + contact_points=["10.0.0.1"], + ssl_context=ssl_ctx, + client_routes_config=config, + ) + self.assertIn("check_hostname", str(cm.exception)) + + def test_check_hostname_with_ssl_options_raises(self): + """Cluster should reject check_hostname=True in ssl_options with client_routes_config.""" + config = ClientRoutesConfig( + proxies=[ClientRouteProxy(str(uuid.uuid4()), "10.0.0.1")] + ) + with self.assertRaises(ValueError) as cm: + Cluster( + contact_points=["10.0.0.1"], + ssl_options={'check_hostname': True}, + client_routes_config=config, + ) + self.assertIn("check_hostname", str(cm.exception)) + + def test_disabled_check_hostname_with_client_routes_ok(self): + """Cluster should allow check_hostname=False with client_routes_config.""" + ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ssl_ctx.check_hostname = False + + config = ClientRoutesConfig( + proxies=[ClientRouteProxy(str(uuid.uuid4()), "10.0.0.1")] + ) + # Should not raise + cluster = Cluster( + contact_points=["10.0.0.1"], + ssl_context=ssl_ctx, + client_routes_config=config, + ) + cluster.shutdown() + + def test_no_ssl_with_client_routes_ok(self): + """Cluster should allow client_routes_config without SSL.""" + config = ClientRoutesConfig( + proxies=[ClientRouteProxy(str(uuid.uuid4()), "10.0.0.1")] + ) + # Should not raise + cluster = Cluster( + contact_points=["10.0.0.1"], + client_routes_config=config, + ) + cluster.shutdown() + + +if __name__ == '__main__': + unittest.main()