diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 5e7a68bc1c..fb5020f70f 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -56,6 +56,9 @@ from cassandra.cqltypes import UserType import cassandra.cqltypes as types from cassandra.encoder import Encoder +from cassandra.events import (_EventBus, DriverEvent, HOST, HOST_ADDED, + HOST_REMOVED, HOST_UP, HOST_DOWN, + HOST_CHANGED, HostEventPayload) from cassandra.protocol import (QueryMessage, ResultMessage, ErrorMessage, ReadTimeoutErrorMessage, WriteTimeoutErrorMessage, @@ -505,8 +508,26 @@ def __init__(self, load_balancing_policy=None, retry_policy=None, class ProfileManager(object): - def __init__(self): + _EVENT_TYPES = (HOST_CHANGED,) + + def __init__(self, event_bus=None): self.profiles = dict() + self._event_bus = event_bus + if event_bus: + for event_type in self._EVENT_TYPES: + event_bus.subscribe(event_type, self._handle_host_event) + + def _handle_host_event(self, event): + payload = event.payload + if event.type == HOST_CHANGED: + if payload.old_host is not payload.new_host: + self.on_change(payload.old_host, payload.new_host, payload.changed_fields) + + def shutdown(self): + if self._event_bus: + for event_type in self._EVENT_TYPES: + self._event_bus.unsubscribe(event_type, self._handle_host_event) + self._event_bus = None def _profiles_without_explicit_lbps(self): names = (profile_name for @@ -548,6 +569,10 @@ def on_remove(self, host): for p in self.profiles.values(): p.load_balancing_policy.on_remove(host) + def on_change(self, old_host, new_host, changed_fields): + for p in self.profiles.values(): + p.load_balancing_policy.on_change(old_host, new_host, changed_fields) + @property def default(self): """ @@ -556,6 +581,44 @@ def default(self): return self.profiles[EXEC_PROFILE_DEFAULT] +class _HostStateListenerAdapter(object): + + def __init__(self, cluster, event_bus): + self._cluster = weakref.proxy(cluster) + self._event_bus = event_bus + event_bus.subscribe(HOST_UP, self._handle_host_event) + event_bus.subscribe(HOST_DOWN, self._handle_host_event) + event_bus.subscribe(HOST_ADDED, self._handle_host_event) + event_bus.subscribe(HOST_REMOVED, self._handle_host_event) + + def _handle_host_event(self, event): + try: + listeners = self._cluster.listeners + except ReferenceError: + self.shutdown() + return + + payload = event.payload + host = payload.host + for listener in listeners: + if event.type == HOST_UP: + listener.on_up(host) + elif event.type == HOST_DOWN: + listener.on_down(host) + elif event.type == HOST_ADDED: + listener.on_add(host) + elif event.type == HOST_REMOVED: + listener.on_remove(host) + + def shutdown(self): + if self._event_bus: + self._event_bus.unsubscribe(HOST_UP, self._handle_host_event) + self._event_bus.unsubscribe(HOST_DOWN, self._handle_host_event) + self._event_bus.unsubscribe(HOST_ADDED, self._handle_host_event) + self._event_bus.unsubscribe(HOST_REMOVED, self._handle_host_event) + self._event_bus = None + + EXEC_PROFILE_DEFAULT = object() """ Key for the ``Cluster`` default execution profile, used when no other profile is selected in @@ -1395,7 +1458,8 @@ def __init__(self, else: self.timestamp_generator = MonotonicTimestampGenerator() - self.profile_manager = ProfileManager() + self._event_bus = _EventBus() + self.profile_manager = ProfileManager(self._event_bus) self.profile_manager.profiles[EXEC_PROFILE_DEFAULT] = ExecutionProfile( self.load_balancing_policy, self.default_retry_policy, @@ -1483,11 +1547,12 @@ def __init__(self, self._listeners = set() self._listener_lock = Lock() + self._host_listener_adapter = None # let Session objects be GC'ed (and shutdown) when the user no longer # holds a reference. self.sessions = WeakSet() - self.metadata = Metadata() + self.metadata = Metadata(self._event_bus) self.control_connection = None self._prepared_statements = WeakValueDictionary() self._prepared_statement_lock = Lock() @@ -1509,6 +1574,7 @@ def __init__(self, self.status_event_refresh_window, schema_metadata_enabled, token_metadata_enabled, schema_meta_page_size=schema_metadata_page_size) + self._host_listener_adapter = _HostStateListenerAdapter(self, self._event_bus) if client_id is None: self.client_id = uuid.uuid4() @@ -1838,6 +1904,13 @@ def shutdown(self): for session in tuple(self.sessions): session.shutdown() + if self._host_listener_adapter: + self._host_listener_adapter.shutdown() + self._host_listener_adapter = None + + if self.profile_manager: + self.profile_manager.shutdown() + self.executor.shutdown() if self.metrics_enabled and self.metrics: @@ -1862,6 +1935,19 @@ def _session_register_user_types(self, session): for udt_name, klass in type_map.items(): session.user_type_registered(keyspace, udt_name, klass) + def _publish_host_event(self, event_type, host=None, old_host=None, + new_host=None, changed_fields=(), old_values=None, + new_values=None, refresh_nodes=True, source=None): + payload = HostEventPayload( + host=host, + old_host=old_host, + new_host=new_host, + changed_fields=changed_fields, + old_values=old_values, + new_values=new_values, + refresh_nodes=refresh_nodes) + return self._event_bus.publish(DriverEvent(event_type, HOST, payload, source or self)) + def _cleanup_failed_on_up_handling(self, host): self.profile_manager.on_down(host) self.control_connection.on_down(host) @@ -1897,8 +1983,7 @@ def _on_up_future_completed(self, host, futures, results, lock, finished_future) log.info("Connection pools established for node %s", host) # mark the host as up and notify all listeners host.set_up() - for listener in self.listeners: - listener.on_up(host) + self._publish_host_event(HOST_UP, host=host) finally: with host.lock: host._currently_handling_node_up = False @@ -1975,6 +2060,7 @@ def on_up(self, host): with host.lock: host.set_up() host._currently_handling_node_up = False + self._publish_host_event(HOST_UP, host=host) # for testing purposes return futures @@ -2007,11 +2093,7 @@ def _start_reconnector(self, host, is_host_addition): def on_down_potentially_blocking(self, host, is_host_addition): self.profile_manager.on_down(host) self.control_connection.on_down(host) - for session in tuple(self.sessions): - session.on_down(host) - - for listener in self.listeners: - listener.on_down(host) + self._publish_host_event(HOST_DOWN, host=host) self._start_reconnector(host, is_host_addition) @@ -2061,7 +2143,7 @@ def on_add(self, host, refresh_nodes=True): if distance == HostDistance.IGNORED: log.debug("Not adding connection pool for new host %r because the " "load balancing policy has marked it as IGNORED", host) - self._finalize_add(host, set_up=False) + self._finalize_add(host, set_up=False, refresh_nodes=refresh_nodes) return futures_lock = Lock() @@ -2090,7 +2172,7 @@ def future_completed(future): log.warning("Connection pool could not be created, not marking node %s up", host) return - self._finalize_add(host) + self._finalize_add(host, refresh_nodes=refresh_nodes) have_future = False for session in tuple(self.sessions): @@ -2101,31 +2183,28 @@ def future_completed(future): future.add_done_callback(future_completed) if not have_future: - self._finalize_add(host) + self._finalize_add(host, refresh_nodes=refresh_nodes) - def _finalize_add(self, host, set_up=True): + def _finalize_add(self, host, set_up=True, refresh_nodes=True): if set_up: host.set_up() - for listener in self.listeners: - listener.on_add(host) - # see if there are any pools to add or remove now that the host is marked up for session in tuple(self.sessions): session.update_created_pools() - def on_remove(self, host): + self._publish_host_event(HOST_ADDED, host=host, refresh_nodes=refresh_nodes) + + def on_remove(self, host, source=None): if self.is_shutdown: return log.debug("[cluster] Removing host %s", host) host.set_down() self.profile_manager.on_remove(host) - for session in tuple(self.sessions): - session.on_remove(host) - for listener in self.listeners: - listener.on_remove(host) - self.control_connection.on_remove(host) + if source is not self.control_connection: + self.control_connection.on_remove(host) + self._publish_host_event(HOST_REMOVED, host=host, source=source) reconnection_handler = host.get_and_set_reconnection_handler(None) if reconnection_handler: @@ -2148,21 +2227,23 @@ def add_host(self, endpoint, datacenter=None, rack=None, signal=True, refresh_no with self.metadata._hosts_lock: if endpoint in self.metadata._host_id_by_endpoint: return self.metadata._hosts[self.metadata._host_id_by_endpoint[endpoint]], False - host, new = self.metadata.add_or_return_host(Host(endpoint, self.conviction_policy_factory, datacenter, rack, host_id=host_id)) + host, new = self.metadata.add_or_return_host( + Host(endpoint, self.conviction_policy_factory, datacenter, rack, + host_id=host_id, event_bus=self._event_bus)) if new and signal: log.info("New Cassandra host %r discovered", host) self.on_add(host, refresh_nodes) return host, new - def remove_host(self, host): + def remove_host(self, host, source=None): """ Called when the control connection observes that a node has left the ring. Intended for internal use only. """ if host and self.metadata.remove_host(host): log.info("Cassandra host %s removed", host) - self.on_remove(host) + self.on_remove(host, source=source) def register_listener(self, listener): """ @@ -2387,6 +2468,34 @@ def add_prepared(self, query_id, prepared_statement): with self._prepared_statement_lock: self._prepared_statements[query_id] = prepared_statement + +class _SessionHostEventHandler(object): + + _EVENT_TYPES = (HOST_DOWN, HOST_REMOVED, HOST_CHANGED) + + def __init__(self, session, event_bus): + self._session_ref = weakref.ref(session, self._session_finalized) + self._event_bus = event_bus + for event_type in self._EVENT_TYPES: + event_bus.subscribe(event_type, self) + + def _session_finalized(self, session_ref): + self.shutdown() + + def __call__(self, event): + session = self._session_ref() + if session is None: + self.shutdown() + return + session._handle_host_event(event) + + def shutdown(self): + if self._event_bus: + for event_type in self._EVENT_TYPES: + self._event_bus.unsubscribe(event_type, self) + self._event_bus = None + + class Session(object): """ A collection of connection pools for each host in the cluster. @@ -2606,6 +2715,7 @@ def default_serial_consistency_level(self, cl): _profile_manager = None _metrics = None _request_init_callbacks = None + _host_event_handler = None _graph_paging_available = False def __init__(self, cluster, hosts, keyspace=None): @@ -2639,6 +2749,8 @@ def __init__(self, cluster, hosts, keyspace=None): msg += " using keyspace '%s'" % self.keyspace raise NoHostAvailable(msg, [h.address for h in hosts]) + self._host_event_handler = _SessionHostEventHandler(self, self.cluster._event_bus) + self.session_id = uuid.uuid4() if self.cluster.column_encryption_policy is not None: @@ -3165,7 +3277,8 @@ def prepare_on_all_hosts(self, query, excluded_host, keyspace=None): Intended for internal use only. """ futures = [] - for host in tuple(self._pools.keys()): + for pool in tuple(self._pools.values()): + host = pool.host if host != excluded_host and host.is_up: future = ResponseFuture(self, PrepareMessage(query=query, keyspace=keyspace), None, self.default_timeout) @@ -3217,6 +3330,10 @@ def shutdown(self): for pool in tuple(self._pools.values()): pool.shutdown() + if self._host_event_handler: + self._host_event_handler.shutdown() + self._host_event_handler = None + def __enter__(self): return self @@ -3256,7 +3373,7 @@ def run_add_or_renew_pool(): host, conn_exc, is_host_addition, expect_host_to_be_down=True) return False - previous = self._pools.get(host) + previous = self._pools.get(host.host_id) with self._lock: while new_pool._keyspace != self.keyspace: self._lock.release() @@ -3276,7 +3393,7 @@ def callback(pool, errors): self._lock.acquire() return False self._lock.acquire() - self._pools[host] = new_pool + self._pools[host.host_id] = new_pool log.debug("Added pool for host %s to session", host) if previous: @@ -3287,7 +3404,7 @@ def callback(pool, errors): return self.submit(run_add_or_renew_pool) def remove_pool(self, host): - pool = self._pools.pop(host, None) + pool = self._pools.pop(host.host_id, None) if pool: log.debug("Removed connection pool for %r", host) return self.submit(pool.shutdown) @@ -3309,7 +3426,7 @@ def update_created_pools(self): futures = set() for host in self.cluster.metadata.all_hosts(): distance = self._profile_manager.distance(host) - pool = self._pools.get(host) + pool = self._pools.get(host.host_id) future = None if not pool or pool.is_shutdown: # we don't eagerly set is_up on previously ignored hosts. None is included here @@ -3340,6 +3457,30 @@ def on_remove(self, host): """ Internal """ self.on_down(host) + def on_change(self, old_host, new_host, changed_fields): + """ Internal """ + pool = self._pools.get(new_host.host_id) + if not pool: + self.update_created_pools() + return None + + if "endpoint" in changed_fields: + return self.add_or_renew_pool(new_host, is_host_addition=False) + + pool.rebind_host(new_host) + self.update_created_pools() + return None + + def _handle_host_event(self, event): + payload = event.payload + if event.type == HOST_DOWN: + self.on_down(payload.host) + elif event.type == HOST_REMOVED: + self.on_remove(payload.host) + elif event.type == HOST_CHANGED: + if payload.old_host is not payload.new_host: + self.on_change(payload.old_host, payload.new_host, payload.changed_fields) + def set_keyspace(self, keyspace): """ Set the default keyspace for all queries made through this Session. @@ -3408,7 +3549,10 @@ def submit(self, fn, *args, **kwargs): return self.cluster.executor.submit(fn, *args, **kwargs) def get_pool_state(self): - return dict((host, pool.get_state()) for host, pool in tuple(self._pools.items())) + return dict((pool.host, pool.get_state()) for pool in tuple(self._pools.values())) + + def get_pool_state_by_host_id(self): + return dict((host_id, pool.get_state()) for host_id, pool in tuple(self._pools.items())) def get_pools(self): return self._pools.values() @@ -3480,6 +3624,8 @@ class ControlConnection(object): Internal """ + _HOST_EVENT_TYPES = (HOST_CHANGED,) + _SELECT_PEERS = "SELECT peer, data_center, host_id, rack, release_version, rpc_address, schema_version, tokens FROM system.peers" _SELECT_PEERS_NO_TOKENS_TEMPLATE = "SELECT host_id, peer, data_center, rack, rpc_address, {nt_col_name}, release_version, schema_version FROM system.peers" _SELECT_LOCAL = "SELECT broadcast_address, cluster_name, data_center, host_id, listen_address, partitioner, rack, release_version, rpc_address, schema_version, tokens FROM system.local WHERE key='local'" @@ -3547,6 +3693,15 @@ def __init__(self, cluster, timeout, self._reconnection_lock = RLock() self._event_schedule_times = {} + self._event_bus = getattr(cluster, '_event_bus', None) + if self._event_bus: + for event_type in self._HOST_EVENT_TYPES: + self._event_bus.subscribe(event_type, self._handle_host_event) + + def _handle_host_event(self, event): + payload = event.payload + if event.type == HOST_CHANGED: + self.on_change(payload.old_host, payload.new_host, payload.changed_fields) def connect(self): if self._is_shutdown: @@ -3770,6 +3925,10 @@ def shutdown(self): if self._connection: self._connection.close() self._connection = None + if self._event_bus: + for event_type in self._HOST_EVENT_TYPES: + self._event_bus.unsubscribe(event_type, self._handle_host_event) + self._event_bus = None def refresh_schema(self, force=False, **kwargs): try: @@ -3884,6 +4043,19 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, host = self._cluster.metadata.get_host(endpoint) datacenter = row.get("data_center") rack = row.get("rack") + host_fields = { + "endpoint": endpoint, + "datacenter": datacenter, + "rack": rack, + "broadcast_address": _NodeInfo.get_broadcast_address(row), + "broadcast_port": _NodeInfo.get_broadcast_port(row), + "broadcast_rpc_address": _NodeInfo.get_broadcast_rpc_address(row), + "broadcast_rpc_port": _NodeInfo.get_broadcast_rpc_port(row), + "release_version": row.get("release_version"), + "dse_version": row.get("dse_version"), + "dse_workload": row.get("workload"), + "dse_workloads": row.get("workloads"), + } if host is None: host = self._cluster.metadata.get_host_by_host_id(host_id) @@ -3892,40 +4064,36 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, reconnector = host.get_and_set_reconnection_handler(None) if reconnector: reconnector.cancel() - self._cluster.on_down(host, is_host_addition=False, expect_host_to_be_down=True) - - old_endpoint = host.endpoint - host.endpoint = endpoint - self._cluster.metadata.update_host(host, old_endpoint) - self._cluster.on_up(host) if host is None: log.debug("[control connection] Found new host to connect to: %s", endpoint) host, _ = self._cluster.add_host(endpoint, datacenter=datacenter, rack=rack, signal=True, refresh_nodes=False, host_id=host_id) should_rebuild_token_map = True + + replace_host = getattr(self._cluster.metadata, "replace_host", None) + if replace_host: + host, changed_fields = replace_host(host_id, source=self, **host_fields) + should_rebuild_token_map |= bool(changed_fields) else: - should_rebuild_token_map |= self._update_location_info(host, datacenter, rack) - - host.host_id = host_id - host.broadcast_address = _NodeInfo.get_broadcast_address(row) - host.broadcast_port = _NodeInfo.get_broadcast_port(row) - host.broadcast_rpc_address = _NodeInfo.get_broadcast_rpc_address(row) - host.broadcast_rpc_port = _NodeInfo.get_broadcast_rpc_port(row) - host.release_version = row.get("release_version") - host.dse_version = row.get("dse_version") - host.dse_workload = row.get("workload") - host.dse_workloads = row.get("workloads") + old_endpoint = host.endpoint + for field, value in host_fields.items(): + if field == "datacenter": + object.__setattr__(host, "_datacenter", value) + elif field == "rack": + object.__setattr__(host, "_rack", value) + else: + object.__setattr__(host, field, value) + self._cluster.metadata.update_host(host, old_endpoint=old_endpoint) tokens = row.get("tokens", None) if partitioner and tokens and self._token_meta_enabled: token_map[host] = tokens - self._cluster.metadata.update_host(host, old_endpoint=endpoint) for old_host_id, old_host in self._cluster.metadata.all_hosts_items(): if old_host_id not in found_host_ids: should_rebuild_token_map = True log.debug("[control connection] Removing host not found in peers metadata: %r", old_host) - self._cluster.metadata.remove_host_by_host_id(old_host_id, old_host.endpoint) + self._cluster.remove_host(old_host, source=self) log.debug("[control connection] Finished fetching ring info") if partitioner and should_rebuild_token_map: @@ -3973,12 +4141,13 @@ def _update_location_info(self, host, datacenter, rack): if host.datacenter == datacenter and host.rack == rack: return False - # If the dc/rack information changes, we need to update the load balancing policy. - # For that, we remove and re-add the node against the policy. Not the most elegant, and assumes - # that the policy will update correctly, but in practice this should work. - self._cluster.profile_manager.on_down(host) - host.set_location_info(datacenter, rack) - self._cluster.profile_manager.on_up(host) + replace_host = getattr(self._cluster.metadata, "replace_host", None) + if replace_host: + replace_host(host.host_id, source=self, datacenter=datacenter, rack=rack) + return True + + new_host = host.set_location_info(datacenter, rack) + self._cluster.metadata.update_host(new_host, old_endpoint=host.endpoint) return True def _delay_for_event_type(self, event_type, delay_window): @@ -4260,6 +4429,16 @@ def on_remove(self, host): else: self.refresh_node_list_and_token_map(force_token_rebuild=True) + def on_change(self, old_host, new_host, changed_fields): + if "endpoint" not in changed_fields: + return + + c = self._connection + if c and c.endpoint == old_host.endpoint: + log.debug("[control connection] Control connection host (%s) endpoint changed to %s. Reconnecting", + old_host, new_host.endpoint) + self.reconnect() + def get_connections(self): c = getattr(self, '_connection', None) return [c] if c else [] @@ -4443,6 +4622,10 @@ class ResponseFuture(object): _warned_timeout = False + @staticmethod + def _pool_key(host): + return host.host_id if isinstance(host, Host) else host + def __init__(self, session, message, query, timeout, metrics=None, prepared_statement=None, retry_policy=RetryPolicy(), row_factory=None, load_balancer=None, start_time=None, speculative_execution_plan=None, continuous_paging_state=None, host=None): @@ -4524,7 +4707,7 @@ def _on_timeout(self, _attempts=0): # Capture connection stats before pool.return_connection() can alter state conn_in_flight = self._connection.in_flight - pool = self.session._pools.get(self._current_host) + pool = self.session._pools.get(self._pool_key(self._current_host)) if pool and not pool.is_shutdown: # Do not return the stream ID to the pool yet. We cannot reuse it # because the node might still be processing the query and will @@ -4607,7 +4790,7 @@ def _query(self, host, message=None, cb=None): if message is None: message = self.message - pool = self.session._pools.get(host) + pool = self.session._pools.get(self._pool_key(host)) if not pool: self._errors[host] = ConnectionException("Host has been marked down or removed") return None diff --git a/cassandra/events.py b/cassandra/events.py new file mode 100644 index 0000000000..c901e4c2d7 --- /dev/null +++ b/cassandra/events.py @@ -0,0 +1,150 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Internal driver event primitives. + +This module intentionally does not expose a public subscription API. It is a +small synchronous bus used to decouple driver subsystems that need to react to +shared internal state changes. +""" + +from collections import defaultdict +import logging +from threading import RLock + + +log = logging.getLogger(__name__) + + +HOST = "HOST" + +HOST_ADDED = "HOST_ADDED" +HOST_REMOVED = "HOST_REMOVED" +HOST_UP = "HOST_UP" +HOST_DOWN = "HOST_DOWN" +HOST_CHANGED = "HOST_CHANGED" + + +class DriverEvent(object): + """ + Internal event envelope. + """ + + __slots__ = ("type", "category", "payload", "source") + + def __init__(self, event_type, category, payload=None, source=None): + self.type = event_type + self.category = category + self.payload = payload + self.source = source + + def __repr__(self): + return "%s(type=%r, category=%r, payload=%r, source=%r)" % ( + self.__class__.__name__, self.type, self.category, self.payload, self.source) + + +class HostEventPayload(object): + """ + Payload for host topology and runtime-state events. + """ + + __slots__ = ( + "host_id", "host", "old_host", "new_host", "changed_fields", + "old_values", "new_values", "refresh_nodes") + + def __init__(self, host=None, host_id=None, old_host=None, new_host=None, + changed_fields=(), old_values=None, new_values=None, + refresh_nodes=True): + if host is None: + host = new_host if new_host is not None else old_host + + if host_id is None and host is not None: + host_id = host.host_id + + self.host_id = host_id + self.host = host + self.old_host = old_host + self.new_host = new_host + self.changed_fields = tuple(changed_fields or ()) + self.old_values = old_values or {} + self.new_values = new_values or {} + self.refresh_nodes = refresh_nodes + + def __repr__(self): + return ("%s(host_id=%r, host=%r, old_host=%r, new_host=%r, " + "changed_fields=%r, old_values=%r, new_values=%r)") % ( + self.__class__.__name__, self.host_id, self.host, + self.old_host, self.new_host, self.changed_fields, + self.old_values, self.new_values) + + +class _EventBus(object): + """ + Synchronous internal event bus. + """ + + def __init__(self): + self._type_subscribers = defaultdict(list) + self._category_subscribers = defaultdict(list) + self._lock = RLock() + + def subscribe(self, event_type, handler): + with self._lock: + handlers = self._type_subscribers[event_type] + if handler not in handlers: + handlers.append(handler) + + def unsubscribe(self, event_type, handler): + with self._lock: + self._remove_handler(self._type_subscribers.get(event_type), handler) + + def subscribe_category(self, category, handler): + with self._lock: + handlers = self._category_subscribers[category] + if handler not in handlers: + handlers.append(handler) + + def unsubscribe_category(self, category, handler): + with self._lock: + self._remove_handler(self._category_subscribers.get(category), handler) + + def publish(self, event): + handlers = self._handlers_for_event(event) + for handler in handlers: + try: + handler(event) + except Exception: + log.exception("Error dispatching driver event %s to %r", event.type, handler) + return event + + @staticmethod + def _remove_handler(handlers, handler): + if not handlers: + return + try: + handlers.remove(handler) + except ValueError: + pass + + def _handlers_for_event(self, event): + with self._lock: + raw_handlers = list(self._type_subscribers.get(event.type, ())) + raw_handlers.extend(self._category_subscribers.get(event.category, ())) + + handlers = [] + for handler in raw_handlers: + if handler not in handlers: + handlers.append(handler) + return handlers diff --git a/cassandra/metadata.py b/cassandra/metadata.py index 43399b7152..2aff692c68 100644 --- a/cassandra/metadata.py +++ b/cassandra/metadata.py @@ -37,6 +37,7 @@ from cassandra import SignatureDescriptor, ConsistencyLevel, InvalidRequest, Unauthorized import cassandra.cqltypes as types from cassandra.encoder import Encoder +from cassandra.events import DriverEvent, HOST, HOST_CHANGED, HostEventPayload from cassandra.marshal import varint_unpack from cassandra.protocol import QueryMessage from cassandra.query import dict_factory, bind_params @@ -121,14 +122,22 @@ class Metadata(object): dbaas = False """ A boolean indicating if connected to a DBaaS cluster """ - def __init__(self): + def __init__(self, event_bus=None): self.keyspaces = {} self.dbaas = False self._hosts = {} self._host_id_by_endpoint = {} + self._runtime_states = {} + self._event_bus = event_bus self._hosts_lock = RLock() self._tablets = Tablets({}) + def set_event_bus(self, event_bus): + self._event_bus = event_bus + with self._hosts_lock: + for runtime_state in self._runtime_states.values(): + runtime_state.set_event_bus(event_bus) + def export_schema_as_string(self): """ Returns a string that can be executed as a query in order to recreate @@ -340,6 +349,7 @@ def add_or_return_host(self, host): try: return self._hosts[host.host_id], False except KeyError: + host = self._bind_runtime_state(host) self._host_id_by_endpoint[host.endpoint] = host.host_id self._hosts[host.host_id] = host return host, True @@ -347,14 +357,22 @@ def add_or_return_host(self, host): def remove_host(self, host): self._tablets.drop_tablets_by_host_id(host.host_id) with self._hosts_lock: + current_host = self._hosts.get(host.host_id) self._host_id_by_endpoint.pop(host.endpoint, False) + if current_host is not None: + self._host_id_by_endpoint.pop(current_host.endpoint, False) + self._runtime_states.pop(host.host_id, None) return bool(self._hosts.pop(host.host_id, False)) def remove_host_by_host_id(self, host_id, endpoint=None): self._tablets.drop_tablets_by_host_id(host_id) with self._hosts_lock: - if endpoint and self._host_id_by_endpoint[endpoint] == host_id: + current_host = self._hosts.get(host_id) + if endpoint and self._host_id_by_endpoint.get(endpoint) == host_id: self._host_id_by_endpoint.pop(endpoint, False) + if current_host is not None: + self._host_id_by_endpoint.pop(current_host.endpoint, False) + self._runtime_states.pop(host_id, None) return bool(self._hosts.pop(host_id, False)) def update_host(self, host, old_endpoint): @@ -363,6 +381,65 @@ def update_host(self, host, old_endpoint): self._host_id_by_endpoint.pop(old_endpoint, False) self._host_id_by_endpoint[host.endpoint] = host.host_id + def replace_host(self, host_id, source=None, **fields): + """ + Replace a Host topology snapshot for host_id and publish HOST_CHANGED. + """ + with self._hosts_lock: + old_host = self._hosts.get(host_id) + if old_host is None: + return None, () + + changed_fields = [] + old_values = {} + new_values = {} + + for field, new_value in fields.items(): + old_value = getattr(old_host, field) + if old_value != new_value: + changed_fields.append(field) + old_values[field] = old_value + new_values[field] = new_value + + if not changed_fields: + return old_host, () + + copy_kwargs = dict((field, fields[field]) for field in changed_fields) + new_host = old_host.copy_with(**copy_kwargs) + new_host.runtime_state.set_event_bus(self._event_bus) + new_host.runtime_state.bind_host(new_host) + + self._hosts[host_id] = new_host + if "endpoint" in changed_fields: + if self._host_id_by_endpoint.get(old_host.endpoint) == host_id: + self._host_id_by_endpoint.pop(old_host.endpoint, False) + self._host_id_by_endpoint[new_host.endpoint] = host_id + else: + self._host_id_by_endpoint[new_host.endpoint] = host_id + + payload = HostEventPayload( + host_id=host_id, + old_host=old_host, + new_host=new_host, + changed_fields=tuple(changed_fields), + old_values=old_values, + new_values=new_values) + if self._event_bus: + self._event_bus.publish(DriverEvent(HOST_CHANGED, HOST, payload, source or self)) + return new_host, tuple(changed_fields) + + def _bind_runtime_state(self, host): + runtime_state = self._runtime_states.get(host.host_id) + if runtime_state is None: + runtime_state = host.runtime_state + self._runtime_states[host.host_id] = runtime_state + elif host.runtime_state is not runtime_state: + host = host.copy_with(runtime_state=runtime_state) + + runtime_state.set_event_bus(self._event_bus) + runtime_state.bind_host(host) + return host + def get_host(self, endpoint_or_address, port=None): """ Find a host in the metadata for a specific endpoint. If a string inet address and port are passed, diff --git a/cassandra/policies.py b/cassandra/policies.py index ceb5ebdc45..a23fe2e25b 100644 --- a/cassandra/policies.py +++ b/cassandra/policies.py @@ -101,6 +101,10 @@ def on_remove(self, host): """ Called when a node is removed from the cluster. """ raise NotImplementedError() + def on_change(self, old_host, new_host, changed_fields): + """ Called when immutable topology metadata for a node is replaced. """ + pass + class LoadBalancingPolicy(HostStateListener): """ @@ -214,6 +218,15 @@ def on_remove(self, host): with self._hosts_lock: self._live_hosts = self._live_hosts.difference((host, )) + def on_change(self, old_host, new_host, changed_fields): + with self._hosts_lock: + if old_host in self._live_hosts: + self._live_hosts = frozenset( + new_host if host == old_host else host + for host in self._live_hosts) + elif new_host.is_up: + self._live_hosts = self._live_hosts.union((new_host, )) + class DCAwareRoundRobinPolicy(LoadBalancingPolicy): """ @@ -324,6 +337,17 @@ def on_add(self, host): def on_remove(self, host): self.on_down(host) + def on_change(self, old_host, new_host, changed_fields): + old_dc = self._dc(old_host) + with self._hosts_lock: + was_live = old_host in self._dc_live_hosts.get(old_dc, ()) + + if was_live: + self.on_down(old_host) + self.on_up(new_host) + elif new_host.is_up: + self.on_up(new_host) + class RackAwareRoundRobinPolicy(LoadBalancingPolicy): """ Similar to :class:`.DCAwareRoundRobinPolicy`, but prefers hosts @@ -449,6 +473,19 @@ def on_add(self, host): def on_remove(self, host): self.on_down(host) + def on_change(self, old_host, new_host, changed_fields): + old_dc = self._dc(old_host) + old_rack = self._rack(old_host) + with self._hosts_lock: + was_live = (old_host in self._live_hosts.get((old_dc, old_rack), ()) or + old_host in self._dc_live_hosts.get(old_dc, ())) + + if was_live: + self.on_down(old_host) + self.on_up(new_host) + elif new_host.is_up: + self.on_up(new_host) + class TokenAwarePolicy(LoadBalancingPolicy): """ A :class:`.LoadBalancingPolicy` wrapper that adds token awareness to @@ -540,6 +577,9 @@ def on_add(self, *args, **kwargs): def on_remove(self, *args, **kwargs): return self._child_policy.on_remove(*args, **kwargs) + def on_change(self, *args, **kwargs): + return self._child_policy.on_change(*args, **kwargs) + class WhiteListRoundRobinPolicy(RoundRobinPolicy): """ @@ -593,6 +633,19 @@ def on_add(self, host): if host.address in self._allowed_hosts_resolved: RoundRobinPolicy.on_add(self, host) + def on_change(self, old_host, new_host, changed_fields): + old_allowed = old_host.address in self._allowed_hosts_resolved + new_allowed = new_host.address in self._allowed_hosts_resolved + with self._hosts_lock: + was_live = old_host in self._live_hosts + + if was_live and new_allowed: + RoundRobinPolicy.on_change(self, old_host, new_host, changed_fields) + elif was_live: + RoundRobinPolicy.on_down(self, old_host) + elif new_allowed and new_host.is_up: + RoundRobinPolicy.on_up(self, new_host) + class HostFilterPolicy(LoadBalancingPolicy): """ @@ -654,6 +707,16 @@ def on_add(self, host, *args, **kwargs): def on_remove(self, host, *args, **kwargs): return self._child_policy.on_remove(host, *args, **kwargs) + def on_change(self, old_host, new_host, changed_fields): + old_allowed = self.predicate(old_host) + new_allowed = self.predicate(new_host) + if old_allowed and new_allowed: + return self._child_policy.on_change(old_host, new_host, changed_fields) + elif old_allowed: + return self._child_policy.on_remove(old_host) + elif new_allowed and new_host.is_up: + return self._child_policy.on_add(new_host) + @property def predicate(self): """ @@ -1322,6 +1385,9 @@ def on_add(self, *args, **kwargs): def on_remove(self, *args, **kwargs): return self._child_policy.on_remove(*args, **kwargs) + def on_change(self, *args, **kwargs): + return self._child_policy.on_change(*args, **kwargs) + class DefaultLoadBalancingPolicy(WrapperPolicy): """ diff --git a/cassandra/pool.py b/cassandra/pool.py index 9e949c342c..cdf2646ddd 100644 --- a/cassandra/pool.py +++ b/cassandra/pool.py @@ -32,11 +32,16 @@ from cassandra import AuthenticationFailed from cassandra.connection import ConnectionException, EndPoint, DefaultEndPoint +from cassandra.events import (DriverEvent, HOST, HOST_CHANGED, + HostEventPayload) from cassandra.policies import HostDistance log = logging.getLogger(__name__) +_NOT_SET = object() + + class NoConnectionsAvailable(Exception): """ All existing connections to a given host are busy, or there are @@ -45,6 +50,76 @@ class NoConnectionsAvailable(Exception): pass +class HostRuntimeState(object): + """ + Mutable runtime state shared by Host topology snapshots for one host_id. + """ + + def __init__(self, conviction_policy_factory, host=None, event_bus=None): + if conviction_policy_factory is None: + raise ValueError("conviction_policy_factory may not be None") + + self.lock = RLock() + self.conviction_policy = conviction_policy_factory(host) + self.is_up = None + self._reconnection_handler = None + self._currently_handling_node_up = False + self.sharding_info = None + self._event_bus = event_bus + self._host = None + if host is not None: + self.bind_host(host) + + def bind_host(self, host): + self._host = host + try: + self.conviction_policy.host = host + except AttributeError: + pass + + def set_event_bus(self, event_bus): + self._event_bus = event_bus + + def set_up(self, host): + if not self.is_up: + log.debug("Host %s is now marked up", host.endpoint) + self.conviction_policy.reset() + self.is_up = True + + def set_down(self): + self.is_up = False + + def signal_connection_failure(self, connection_exc): + return self.conviction_policy.add_failure(connection_exc) + + def is_currently_reconnecting(self): + return self._reconnection_handler is not None + + def get_and_set_reconnection_handler(self, new_handler): + with self.lock: + old = self._reconnection_handler + self._reconnection_handler = new_handler + return old + + def set_sharding_info(self, host, sharding_info, source=None): + with self.lock: + old_sharding_info = self.sharding_info + if old_sharding_info == sharding_info: + return False + self.sharding_info = sharding_info + + if self._event_bus: + payload = HostEventPayload( + host=host, + old_host=host, + new_host=host, + changed_fields=("sharding_info",), + old_values={"sharding_info": old_sharding_info}, + new_values={"sharding_info": sharding_info}) + self._event_bus.publish(DriverEvent(HOST_CHANGED, HOST, payload, source or host)) + return True + + @total_ordering class Host(object): """ @@ -160,26 +235,53 @@ class Host(object): _datacenter = None _rack = None - _reconnection_handler = None - lock = None - _currently_handling_node_up = False - - sharding_info = None - - def __init__(self, endpoint, conviction_policy_factory, datacenter=None, rack=None, host_id=None): + _IMMUTABLE_FIELDS = frozenset(( + "endpoint", "host_id", "_datacenter", "_rack", "broadcast_address", + "broadcast_port", "broadcast_rpc_address", "broadcast_rpc_port", + "listen_address", "listen_port", "release_version", "dse_version", + "dse_workload", "dse_workloads")) + + def __init__(self, endpoint, conviction_policy_factory, datacenter=None, rack=None, + host_id=None, broadcast_address=None, broadcast_port=None, + broadcast_rpc_address=None, broadcast_rpc_port=None, + listen_address=None, listen_port=None, release_version=None, + dse_version=None, dse_workload=None, dse_workloads=None, + runtime_state=None, event_bus=None): if endpoint is None: raise ValueError("endpoint may not be None") - if conviction_policy_factory is None: + if conviction_policy_factory is None and runtime_state is None: raise ValueError("conviction_policy_factory may not be None") - - self.endpoint = endpoint if isinstance(endpoint, EndPoint) else DefaultEndPoint(endpoint) - self.conviction_policy = conviction_policy_factory(self) if not host_id: raise ValueError("host_id may not be None") - self.host_id = host_id - self.set_location_info(datacenter, rack) - self.lock = RLock() + + object.__setattr__(self, "_initialized", False) + object.__setattr__(self, "endpoint", endpoint if isinstance(endpoint, EndPoint) else DefaultEndPoint(endpoint)) + object.__setattr__(self, "host_id", host_id) + object.__setattr__(self, "_datacenter", datacenter) + object.__setattr__(self, "_rack", rack) + object.__setattr__(self, "broadcast_address", broadcast_address) + object.__setattr__(self, "broadcast_port", broadcast_port) + object.__setattr__(self, "broadcast_rpc_address", broadcast_rpc_address) + object.__setattr__(self, "broadcast_rpc_port", broadcast_rpc_port) + object.__setattr__(self, "listen_address", listen_address) + object.__setattr__(self, "listen_port", listen_port) + object.__setattr__(self, "release_version", release_version) + object.__setattr__(self, "dse_version", dse_version) + object.__setattr__(self, "dse_workload", dse_workload) + object.__setattr__(self, "dse_workloads", dse_workloads) + + runtime_state = runtime_state or HostRuntimeState(conviction_policy_factory, self, event_bus=event_bus) + runtime_state.bind_host(self) + if event_bus is not None: + runtime_state.set_event_bus(event_bus) + object.__setattr__(self, "_runtime", runtime_state) + object.__setattr__(self, "_initialized", True) + + def __setattr__(self, name, value): + if getattr(self, "_initialized", False) and name in self._IMMUTABLE_FIELDS: + raise AttributeError("Host topology field %r is immutable; replace the Host snapshot instead" % (name,)) + object.__setattr__(self, name, value) @property def address(self): @@ -199,50 +301,117 @@ def rack(self): """ The rack the node is in. """ return self._rack + @property + def lock(self): + return self._runtime.lock + + @property + def conviction_policy(self): + return self._runtime.conviction_policy + + @property + def is_up(self): + return self._runtime.is_up + + @is_up.setter + def is_up(self, value): + self._runtime.is_up = value + + @property + def sharding_info(self): + return self._runtime.sharding_info + + @sharding_info.setter + def sharding_info(self, value): + self._runtime.set_sharding_info(self, value) + + @property + def _reconnection_handler(self): + return self._runtime._reconnection_handler + + @_reconnection_handler.setter + def _reconnection_handler(self, value): + self._runtime._reconnection_handler = value + + @property + def _currently_handling_node_up(self): + return self._runtime._currently_handling_node_up + + @_currently_handling_node_up.setter + def _currently_handling_node_up(self, value): + self._runtime._currently_handling_node_up = value + + @property + def runtime_state(self): + return self._runtime + def set_location_info(self, datacenter, rack): """ - Sets the datacenter and rack for this node. Intended for internal - use (by the control connection, which periodically checks the - ring topology) only. + Return a Host snapshot with updated datacenter and rack. + + Host topology is immutable after construction. Callers that need to + publish topology changes should use :meth:`Metadata.replace_host` so + related caches observe a single HOST_CHANGED event. + """ + return self.copy_with(datacenter=datacenter, rack=rack) + + def copy_with(self, endpoint=_NOT_SET, datacenter=_NOT_SET, rack=_NOT_SET, + broadcast_address=_NOT_SET, broadcast_port=_NOT_SET, + broadcast_rpc_address=_NOT_SET, broadcast_rpc_port=_NOT_SET, + listen_address=_NOT_SET, listen_port=_NOT_SET, + release_version=_NOT_SET, dse_version=_NOT_SET, + dse_workload=_NOT_SET, dse_workloads=_NOT_SET, + runtime_state=_NOT_SET): """ - self._datacenter = datacenter - self._rack = rack + Return a new immutable topology snapshot that shares this host's runtime state. + """ + runtime_state = self._runtime if runtime_state is _NOT_SET else runtime_state + return Host( + self.endpoint if endpoint is _NOT_SET else endpoint, + lambda host: runtime_state.conviction_policy, + self.datacenter if datacenter is _NOT_SET else datacenter, + self.rack if rack is _NOT_SET else rack, + host_id=self.host_id, + broadcast_address=self.broadcast_address if broadcast_address is _NOT_SET else broadcast_address, + broadcast_port=self.broadcast_port if broadcast_port is _NOT_SET else broadcast_port, + broadcast_rpc_address=self.broadcast_rpc_address if broadcast_rpc_address is _NOT_SET else broadcast_rpc_address, + broadcast_rpc_port=self.broadcast_rpc_port if broadcast_rpc_port is _NOT_SET else broadcast_rpc_port, + listen_address=self.listen_address if listen_address is _NOT_SET else listen_address, + listen_port=self.listen_port if listen_port is _NOT_SET else listen_port, + release_version=self.release_version if release_version is _NOT_SET else release_version, + dse_version=self.dse_version if dse_version is _NOT_SET else dse_version, + dse_workload=self.dse_workload if dse_workload is _NOT_SET else dse_workload, + dse_workloads=self.dse_workloads if dse_workloads is _NOT_SET else dse_workloads, + runtime_state=runtime_state) def set_up(self): - if not self.is_up: - log.debug("Host %s is now marked up", self.endpoint) - self.conviction_policy.reset() - self.is_up = True + self._runtime.set_up(self) def set_down(self): - self.is_up = False + self._runtime.set_down() def signal_connection_failure(self, connection_exc): - return self.conviction_policy.add_failure(connection_exc) + return self._runtime.signal_connection_failure(connection_exc) def is_currently_reconnecting(self): - return self._reconnection_handler is not None + return self._runtime.is_currently_reconnecting() def get_and_set_reconnection_handler(self, new_handler): """ Atomically replaces the reconnection handler for this host. Intended for internal use only. """ - with self.lock: - old = self._reconnection_handler - self._reconnection_handler = new_handler - return old + return self._runtime.get_and_set_reconnection_handler(new_handler) def __eq__(self, other): - if isinstance(other, Host): - return self.endpoint == other.endpoint - else: # TODO Backward compatibility, remove next major - return self.endpoint.address == other + return isinstance(other, Host) and self.host_id == other.host_id def __hash__(self): - return hash(self.endpoint) + return hash(self.host_id) def __lt__(self, other): + if self.endpoint == other.endpoint: + return str(self.host_id) < str(other.host_id) return self.endpoint < other.endpoint def __str__(self): @@ -442,6 +611,13 @@ def __init__(self, host, host_distance, session): log.debug("Finished initializing connection for host %s", self.host) + def rebind_host(self, host): + with self._lock: + old_host = self.host + self.host = host + if old_host is not host: + log.debug("Rebound connection pool from host %s to %s", old_host, host) + def _get_connection_for_routing_key(self, routing_key=None, keyspace=None, table=None): if self.is_shutdown: raise ConnectionException( @@ -920,5 +1096,3 @@ def open_count(self): @property def _excess_connection_limit(self): return self.host.sharding_info.shards_count * self.max_excess_connections_per_shard_multiplier - - diff --git a/tests/unit/test_control_connection.py b/tests/unit/test_control_connection.py index 037d4a8888..1ce12d6677 100644 --- a/tests/unit/test_control_connection.py +++ b/tests/unit/test_control_connection.py @@ -32,9 +32,9 @@ class MockMetadata(object): def __init__(self): self.hosts = { - 'uuid1': Host(endpoint=DefaultEndPoint("192.168.1.0"), conviction_policy_factory=SimpleConvictionPolicy, host_id='uuid1'), - 'uuid2': Host(endpoint=DefaultEndPoint("192.168.1.1"), conviction_policy_factory=SimpleConvictionPolicy, host_id='uuid2'), - 'uuid3': Host(endpoint=DefaultEndPoint("192.168.1.2"), conviction_policy_factory=SimpleConvictionPolicy, host_id='uuid3') + 'uuid1': Host(endpoint=DefaultEndPoint("192.168.1.0"), conviction_policy_factory=SimpleConvictionPolicy, host_id='uuid1', release_version="3.11"), + 'uuid2': Host(endpoint=DefaultEndPoint("192.168.1.1"), conviction_policy_factory=SimpleConvictionPolicy, host_id='uuid2', release_version="3.11"), + 'uuid3': Host(endpoint=DefaultEndPoint("192.168.1.2"), conviction_policy_factory=SimpleConvictionPolicy, host_id='uuid3', release_version="3.11") } self._host_id_by_endpoint = { DefaultEndPoint("192.168.1.0"): 'uuid1', @@ -43,7 +43,6 @@ def __init__(self): } for host in self.hosts.values(): host.set_up() - host.release_version = "3.11" self.cluster_name = None self.partitioner = None @@ -83,14 +82,32 @@ def update_host(self, host, old_endpoint): self._host_id_by_endpoint[host.endpoint] = host.host_id self._host_id_by_endpoint.pop(old_endpoint, False) + def replace_host(self, host_id, source=None, **fields): + old_host = self.hosts.get(host_id) + changed_fields = [] + for field, value in fields.items(): + if getattr(old_host, field) != value: + changed_fields.append(field) + + if not changed_fields: + return old_host, () + + new_host = old_host.copy_with(**dict((field, fields[field]) for field in changed_fields)) + self.hosts[host_id] = new_host + if 'endpoint' in changed_fields: + self._host_id_by_endpoint.pop(old_host.endpoint, False) + self._host_id_by_endpoint[new_host.endpoint] = host_id + return new_host, tuple(changed_fields) + def all_hosts_items(self): return list(self.hosts.items()) def remove_host_by_host_id(self, host_id, endpoint=None): if endpoint and self._host_id_by_endpoint[endpoint] == host_id: self._host_id_by_endpoint.pop(endpoint, False) - self.removed_hosts.append(self.hosts.pop(host_id, False)) - return bool(self.hosts.pop(host_id, False)) + removed = self.hosts.pop(host_id, False) + self.removed_hosts.append(removed) + return bool(removed) class MockCluster(object): @@ -118,8 +135,8 @@ def add_host(self, endpoint, datacenter, rack, signal=False, refresh_nodes=True, self.added_hosts.append(host) return host, True - def remove_host(self, host): - pass + def remove_host(self, host, source=None): + self.metadata.remove_host_by_host_id(host.host_id, host.endpoint) def on_up(self, host): pass @@ -420,10 +437,11 @@ def test_refresh_nodes_and_tokens_add_host(self): self.cluster.scheduler.schedule = lambda delay, f, *args, **kwargs: f(*args, **kwargs) self.control_connection.refresh_node_list_and_token_map() assert 1 == len(self.cluster.added_hosts) - assert self.cluster.added_hosts[0].address == "192.168.1.3" - assert self.cluster.added_hosts[0].datacenter == "dc1" - assert self.cluster.added_hosts[0].rack == "rack1" - assert self.cluster.added_hosts[0].host_id == "uuid4" + host = self.cluster.metadata.get_host_by_host_id("uuid4") + assert host.address == "192.168.1.3" + assert host.datacenter == "dc1" + assert host.rack == "rack1" + assert host.host_id == "uuid4" def test_refresh_nodes_and_tokens_remove_host(self): del self.connection.peer_results[1][1] @@ -594,14 +612,15 @@ def test_refresh_nodes_and_tokens_add_host_detects_port(self): self.cluster.scheduler.schedule = lambda delay, f, *args, **kwargs: f(*args, **kwargs) self.control_connection.refresh_node_list_and_token_map() assert 1 == len(self.cluster.added_hosts) - assert self.cluster.added_hosts[0].endpoint.address == "192.168.1.3" - assert self.cluster.added_hosts[0].endpoint.port == 555 - assert self.cluster.added_hosts[0].broadcast_rpc_address == "192.168.1.3" - assert self.cluster.added_hosts[0].broadcast_rpc_port == 555 - assert self.cluster.added_hosts[0].broadcast_address == "10.0.0.3" - assert self.cluster.added_hosts[0].broadcast_port == 666 - assert self.cluster.added_hosts[0].datacenter == "dc1" - assert self.cluster.added_hosts[0].rack == "rack1" + host = self.cluster.metadata.get_host_by_host_id("uuid4") + assert host.endpoint.address == "192.168.1.3" + assert host.endpoint.port == 555 + assert host.broadcast_rpc_address == "192.168.1.3" + assert host.broadcast_rpc_port == 555 + assert host.broadcast_address == "10.0.0.3" + assert host.broadcast_port == 666 + assert host.datacenter == "dc1" + assert host.rack == "rack1" def test_refresh_nodes_and_tokens_add_host_detects_invalid_port(self): del self.connection.peer_results[:] @@ -614,14 +633,15 @@ def test_refresh_nodes_and_tokens_add_host_detects_invalid_port(self): self.cluster.scheduler.schedule = lambda delay, f, *args, **kwargs: f(*args, **kwargs) self.control_connection.refresh_node_list_and_token_map() assert 1 == len(self.cluster.added_hosts) - assert self.cluster.added_hosts[0].endpoint.address == "192.168.1.3" - assert self.cluster.added_hosts[0].endpoint.port == 9042 # fallback default - assert self.cluster.added_hosts[0].broadcast_rpc_address == "192.168.1.3" - assert self.cluster.added_hosts[0].broadcast_rpc_port == None - assert self.cluster.added_hosts[0].broadcast_address == "10.0.0.3" - assert self.cluster.added_hosts[0].broadcast_port == None - assert self.cluster.added_hosts[0].datacenter == "dc1" - assert self.cluster.added_hosts[0].rack == "rack1" + host = self.cluster.metadata.get_host_by_host_id("uuid4") + assert host.endpoint.address == "192.168.1.3" + assert host.endpoint.port == 9042 # fallback default + assert host.broadcast_rpc_address == "192.168.1.3" + assert host.broadcast_rpc_port == None + assert host.broadcast_address == "10.0.0.3" + assert host.broadcast_port == None + assert host.datacenter == "dc1" + assert host.rack == "rack1" class EventTimingTest(unittest.TestCase): diff --git a/tests/unit/test_events.py b/tests/unit/test_events.py new file mode 100644 index 0000000000..d535acc398 --- /dev/null +++ b/tests/unit/test_events.py @@ -0,0 +1,469 @@ +from concurrent.futures import Future +import gc +import logging +from threading import RLock +import time +import uuid +import weakref + +from unittest.mock import ANY, Mock, patch + +import pytest + +from cassandra import ConsistencyLevel +from cassandra.cluster import (Cluster, ResponseFuture, Session, + _SessionHostEventHandler) +from cassandra.connection import DefaultEndPoint +from cassandra.events import (_EventBus, DriverEvent, HOST, HOST_ADDED, + HOST_CHANGED, HOST_DOWN, HostEventPayload) +from cassandra.metadata import Metadata +from cassandra.policies import (HostDistance, LoadBalancingPolicy, + RoundRobinPolicy, SimpleConvictionPolicy) +from cassandra.pool import Host +from cassandra.protocol import ProtocolHandler, QueryMessage +from cassandra.query import SimpleStatement + + +def _completed_future(result=True): + future = Future() + future.set_result(result) + return future + + +def _wait_for(predicate, timeout=1.0): + deadline = time.time() + timeout + while time.time() < deadline: + if predicate(): + return + time.sleep(0.01) + assert predicate() + + +class _RecordingPolicy(LoadBalancingPolicy): + + def __init__(self): + self.events = [] + self.hosts = [] + + def distance(self, host): + return HostDistance.LOCAL + + def populate(self, cluster, hosts): + self.hosts = list(hosts) + + def make_query_plan(self, working_keyspace=None, query=None): + return list(self.hosts) + + def on_up(self, host): + self.events.append(("up", host)) + + def on_down(self, host): + self.events.append(("down", host)) + + def on_add(self, host): + self.events.append(("add", host)) + + def on_remove(self, host): + self.events.append(("remove", host)) + + def on_change(self, old_host, new_host, changed_fields): + self.events.append(("change", old_host, new_host, changed_fields)) + + +class _FakeSession(object): + + def __init__(self): + self.pools = {} + self.update_created_pools_calls = 0 + + def add_or_renew_pool(self, host, is_host_addition): + self.pools[host.host_id] = host + return _completed_future(True) + + def update_created_pools(self): + self.update_created_pools_calls += 1 + + def remove_pool(self, host): + self.pools.pop(host.host_id, None) + return _completed_future(True) + + def shutdown(self): + pass + + +def test_event_bus_dispatches_type_then_category_and_dedupes(): + bus = _EventBus() + calls = [] + + def first(event): + calls.append(("first", event.type)) + + def second(event): + calls.append(("second", event.category)) + + bus.subscribe(HOST_ADDED, first) + bus.subscribe_category(HOST, second) + bus.subscribe_category(HOST, first) + + event = DriverEvent(HOST_ADDED, HOST, payload={"host": "h"}, source="test") + assert bus.publish(event) is event + assert calls == [("first", HOST_ADDED), ("second", HOST)] + + +def test_event_bus_unsubscribe_methods_are_idempotent(): + bus = _EventBus() + calls = [] + + def handler(event): + calls.append(event.type) + + bus.subscribe(HOST_ADDED, handler) + bus.subscribe_category(HOST, handler) + bus.unsubscribe(HOST_ADDED, handler) + bus.unsubscribe(HOST_ADDED, handler) + bus.unsubscribe_category(HOST, handler) + bus.unsubscribe_category(HOST, handler) + + bus.publish(DriverEvent(HOST_ADDED, HOST)) + assert calls == [] + + +def test_event_bus_isolates_subscriber_exceptions(caplog): + bus = _EventBus() + calls = [] + + def broken(event): + raise RuntimeError("boom") + + def working(event): + calls.append(event.type) + + bus.subscribe(HOST_ADDED, broken) + bus.subscribe(HOST_ADDED, working) + + with caplog.at_level(logging.ERROR): + bus.publish(DriverEvent(HOST_ADDED, HOST)) + + assert calls == [HOST_ADDED] + assert "Error dispatching driver event" in caplog.text + + +def test_host_identity_is_host_id_and_topology_fields_are_read_only(): + host_id = uuid.uuid4() + host = Host("127.0.0.1", SimpleConvictionPolicy, host_id=host_id) + same_identity = Host("127.0.0.2", SimpleConvictionPolicy, host_id=host_id) + other_identity = Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4()) + + assert host == same_identity + assert hash(host) == hash(same_identity) + assert host != other_identity + + with pytest.raises(AttributeError): + host.endpoint = DefaultEndPoint("127.0.0.9") + + with pytest.raises(AttributeError): + host.host_id = uuid.uuid4() + + with pytest.raises(AttributeError): + host._datacenter = "dc2" + + +def test_set_location_info_returns_replacement_without_mutating_host(): + host = Host("127.0.0.1", SimpleConvictionPolicy, datacenter="dc1", + rack="rack1", host_id=uuid.uuid4()) + + replacement = host.set_location_info("dc2", "rack2") + + assert replacement is not host + assert replacement.host_id == host.host_id + assert replacement.runtime_state is host.runtime_state + assert host.datacenter == "dc1" + assert host.rack == "rack1" + assert replacement.datacenter == "dc2" + assert replacement.rack == "rack2" + + +def test_host_replacement_reuses_runtime_state_and_updates_endpoint_index(): + bus = _EventBus() + events = [] + bus.subscribe(HOST_CHANGED, events.append) + metadata = Metadata(bus) + host_id = uuid.uuid4() + host, _ = metadata.add_or_return_host( + Host("127.0.0.1", SimpleConvictionPolicy, host_id=host_id)) + + host.set_down() + reconnector = object() + sharding_info = object() + host.get_and_set_reconnection_handler(reconnector) + host.sharding_info = sharding_info + new_host, changed_fields = metadata.replace_host( + host_id, endpoint=DefaultEndPoint("127.0.0.2"), datacenter="dc1") + + assert changed_fields == ("endpoint", "datacenter") + assert new_host is not host + assert new_host == host + assert new_host.runtime_state is host.runtime_state + assert new_host.is_up is False + assert new_host.sharding_info is sharding_info + assert new_host.get_and_set_reconnection_handler(None) is reconnector + assert metadata.get_host(DefaultEndPoint("127.0.0.1")) is None + assert metadata.get_host(DefaultEndPoint("127.0.0.2")) is new_host + assert events[-1].payload.old_host is host + assert events[-1].payload.new_host is new_host + + +def test_sharding_info_change_is_runtime_host_changed_event(): + bus = _EventBus() + events = [] + bus.subscribe(HOST_CHANGED, events.append) + metadata = Metadata(bus) + host_id = uuid.uuid4() + host, _ = metadata.add_or_return_host( + Host("127.0.0.1", SimpleConvictionPolicy, host_id=host_id)) + + sharding_info = object() + host.sharding_info = sharding_info + + assert host.sharding_info is sharding_info + assert len(events) == 1 + assert events[0].payload.old_host is host + assert events[0].payload.new_host is host + assert events[0].payload.changed_fields == ("sharding_info",) + assert events[0].payload.new_values["sharding_info"] is sharding_info + + +def test_public_on_add_fires_after_host_is_up_and_pools_are_ready(): + cluster = Cluster(protocol_version=4) + cluster._prepare_all_queries = Mock() + session = _FakeSession() + cluster.sessions.add(session) + observed = [] + + class Listener(object): + + def on_add(self, host): + observed.append((host.is_up, host.host_id in session.pools, + session.update_created_pools_calls)) + + def on_up(self, host): + pass + + def on_down(self, host): + pass + + def on_remove(self, host): + pass + + try: + cluster.register_listener(Listener()) + host = Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4(), + event_bus=cluster._event_bus) + cluster.metadata.add_or_return_host(host) + + cluster.on_add(host, refresh_nodes=False) + + assert observed == [(True, True, 1)] + finally: + cluster.shutdown() + + +def test_public_host_state_listener_fires_once_per_transition(): + cluster = Cluster(protocol_version=4) + cluster._start_reconnector = Mock() + observed = [] + + class Listener(object): + + def on_add(self, host): + observed.append("add") + + def on_up(self, host): + observed.append("up") + + def on_down(self, host): + observed.append("down") + + def on_remove(self, host): + observed.append("remove") + + try: + cluster.register_listener(Listener()) + host = Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4(), + event_bus=cluster._event_bus) + cluster.metadata.add_or_return_host(host) + + cluster.on_add(host, refresh_nodes=False) + host.set_down() + cluster.on_up(host) + cluster.on_down(host, is_host_addition=False) + _wait_for(lambda: "down" in observed) + cluster.on_remove(host) + + assert observed.count("add") == 1 + assert observed.count("up") == 1 + assert observed.count("down") == 1 + assert observed.count("remove") == 1 + finally: + cluster.shutdown() + + +def test_lbp_receives_one_notification_per_host_transition_and_change(): + policy = _RecordingPolicy() + cluster = Cluster(load_balancing_policy=policy, protocol_version=4) + cluster._start_reconnector = Mock() + + try: + host = Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4(), + event_bus=cluster._event_bus) + cluster.metadata.add_or_return_host(host) + + cluster.on_add(host, refresh_nodes=False) + host.set_down() + cluster.on_up(host) + cluster.on_down(host, is_host_addition=False) + _wait_for(lambda: [event[0] for event in policy.events].count("down") == 1) + cluster.on_remove(host) + + host.sharding_info = object() + cluster.metadata.add_or_return_host(host) + new_host, _ = cluster.metadata.replace_host(host.host_id, datacenter="dc2") + + event_names = [event[0] for event in policy.events] + assert event_names.count("add") == 1 + assert event_names.count("up") == 1 + assert event_names.count("down") == 1 + assert event_names.count("remove") == 1 + assert event_names.count("change") == 1 + assert policy.events[-1] == ("change", host, new_host, ("datacenter",)) + finally: + cluster.shutdown() + + +def test_topology_host_changed_replaces_round_robin_cached_host(): + policy = RoundRobinPolicy() + cluster = Cluster(load_balancing_policy=policy, protocol_version=4) + + try: + host = Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4(), + event_bus=cluster._event_bus) + host.set_up() + cluster.metadata.add_or_return_host(host) + policy.populate(cluster, [host]) + + new_host, _ = cluster.metadata.replace_host(host.host_id, datacenter="dc2") + query_plan = list(policy.make_query_plan()) + + assert any(candidate is new_host for candidate in query_plan) + assert not any(candidate is host for candidate in query_plan) + finally: + cluster.shutdown() + + +def _session_for_pool_tests(): + session = Session.__new__(Session) + session.cluster = Mock() + session.cluster.connect_timeout = 1 + session.cluster.signal_connection_failure = Mock() + session._profile_manager = Mock() + session._profile_manager.distance.return_value = HostDistance.LOCAL + session._pools = {} + session._lock = RLock() + session.keyspace = None + session.submit = lambda fn, *args, **kwargs: _completed_future(fn(*args, **kwargs)) + return session + + +def test_session_pools_are_keyed_by_host_id(): + session = _session_for_pool_tests() + host = Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4()) + + class FakePool(object): + + def __init__(self, host, host_distance, session): + self.host = host + self.host_distance = host_distance + self._keyspace = session.keyspace + self.is_shutdown = False + self.shutdown = Mock() + + with patch("cassandra.cluster.HostConnection", FakePool): + assert session.add_or_renew_pool(host, is_host_addition=False).result() is True + + assert set(session._pools) == {host.host_id} + assert session._pools[host.host_id].host is host + + +def test_session_rebinds_pool_for_non_endpoint_host_replacement(): + session = _session_for_pool_tests() + host = Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4()) + new_host = host.copy_with(datacenter="dc2") + pool = Mock() + pool.is_shutdown = False + pool.host_distance = HostDistance.LOCAL + session._pools[host.host_id] = pool + session.cluster.metadata.all_hosts.return_value = [new_host] + + session.on_change(host, new_host, ("datacenter",)) + + pool.rebind_host.assert_called_once_with(new_host) + + +def test_session_renews_pool_for_endpoint_host_replacement(): + session = _session_for_pool_tests() + host = Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4()) + new_host = host.copy_with(endpoint=DefaultEndPoint("127.0.0.2")) + session._pools[host.host_id] = Mock() + session.add_or_renew_pool = Mock(return_value="future") + + assert session.on_change(host, new_host, ("endpoint",)) == "future" + session.add_or_renew_pool.assert_called_once_with(new_host, is_host_addition=False) + + +def test_response_future_pool_lookup_uses_host_id(): + host = Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4()) + session = Mock(spec=Session) + session.keyspace = None + session.row_factory = lambda column_names, rows: rows + session.cluster.control_connection._tablets_routing_v1 = False + session.cluster._default_load_balancing_policy.make_query_plan.return_value = [host] + pool = Mock() + session._pools.get.side_effect = {host.host_id: pool}.get + connection = Mock() + pool.is_shutdown = False + pool.borrow_connection.return_value = (connection, 1) + + query = SimpleStatement("SELECT * FROM system.local") + message = QueryMessage(query=query.query_string, consistency_level=ConsistencyLevel.ONE) + future = ResponseFuture(session, message, query, 1) + + assert future.send_request() is True + session._pools.get.assert_any_call(host.host_id) + connection.send_msg.assert_called_once_with( + future.message, 1, cb=ANY, + encoder=ProtocolHandler.encode_message, + decoder=ProtocolHandler.decode_message, + result_metadata=[]) + + +def test_session_host_event_handler_unsubscribes_after_session_gc(): + bus = _EventBus() + + class DummySession(object): + + def _handle_host_event(self, event): + pass + + session = DummySession() + handler = _SessionHostEventHandler(session, bus) + session_ref = weakref.ref(session) + + del session + gc.collect() + + assert session_ref() is None + assert handler not in bus._type_subscribers[HOST_DOWN] + + host = Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4()) + bus.publish(DriverEvent(HOST_DOWN, HOST, HostEventPayload(host=host))) + assert handler not in bus._type_subscribers[HOST_DOWN] diff --git a/tests/unit/test_host_connection_pool.py b/tests/unit/test_host_connection_pool.py index f92bb53785..b5ea704881 100644 --- a/tests/unit/test_host_connection_pool.py +++ b/tests/unit/test_host_connection_pool.py @@ -217,13 +217,14 @@ def test_host_equality(self): Test host equality has correct logic """ - a = Host('127.0.0.1', SimpleConvictionPolicy, host_id=uuid.uuid4()) - b = Host('127.0.0.1', SimpleConvictionPolicy, host_id=uuid.uuid4()) - c = Host('127.0.0.2', SimpleConvictionPolicy, host_id=uuid.uuid4()) - - assert a == b, 'Two Host instances should be equal when sharing.' - assert a != c, 'Two Host instances should NOT be equal when using two different addresses.' - assert b != c, 'Two Host instances should NOT be equal when using two different addresses.' + host_id = uuid.uuid4() + a = Host('127.0.0.1', SimpleConvictionPolicy, host_id=host_id) + b = Host('127.0.0.2', SimpleConvictionPolicy, host_id=host_id) + c = Host('127.0.0.1', SimpleConvictionPolicy, host_id=uuid.uuid4()) + + assert a == b, 'Two Host instances should be equal when sharing host_id.' + assert a != c, 'Two Host instances should NOT be equal when using different host_id values.' + assert b != c, 'Two Host instances should NOT be equal when using different host_id values.' class HostConnectionTests(_PoolTests): diff --git a/tests/unit/test_metadata.py b/tests/unit/test_metadata.py index dcbb840447..926f97401e 100644 --- a/tests/unit/test_metadata.py +++ b/tests/unit/test_metadata.py @@ -43,6 +43,10 @@ log = logging.getLogger(__name__) +def _with_location(host, datacenter, rack): + return host.set_location_info(datacenter, rack) + + class ReplicationFactorTest(unittest.TestCase): def test_replication_factor_parsing(self): @@ -193,21 +197,20 @@ def test_nts_make_token_replica_map(self): dc1_1 = Host('dc1.1', SimpleConvictionPolicy, host_id=uuid.uuid4()) dc1_2 = Host('dc1.2', SimpleConvictionPolicy, host_id=uuid.uuid4()) dc1_3 = Host('dc1.3', SimpleConvictionPolicy, host_id=uuid.uuid4()) - for host in (dc1_1, dc1_2, dc1_3): - host.set_location_info('dc1', 'rack1') + dc1_1, dc1_2, dc1_3 = [_with_location(host, 'dc1', 'rack1') for host in (dc1_1, dc1_2, dc1_3)] token_to_host_owner[MD5Token(0)] = dc1_1 token_to_host_owner[MD5Token(100)] = dc1_2 token_to_host_owner[MD5Token(200)] = dc1_3 dc2_1 = Host('dc2.1', SimpleConvictionPolicy, host_id=uuid.uuid4()) dc2_2 = Host('dc2.2', SimpleConvictionPolicy, host_id=uuid.uuid4()) - dc2_1.set_location_info('dc2', 'rack1') - dc2_2.set_location_info('dc2', 'rack1') + dc2_1 = _with_location(dc2_1, 'dc2', 'rack1') + dc2_2 = _with_location(dc2_2, 'dc2', 'rack1') token_to_host_owner[MD5Token(1)] = dc2_1 token_to_host_owner[MD5Token(101)] = dc2_2 dc3_1 = Host('dc3.1', SimpleConvictionPolicy, host_id=uuid.uuid4()) - dc3_1.set_location_info('dc3', 'rack3') + dc3_1 = _with_location(dc3_1, 'dc3', 'rack3') token_to_host_owner[MD5Token(2)] = dc3_1 ring = [MD5Token(0), @@ -242,7 +245,7 @@ def test_nts_token_performance(self): for i in range(dc1hostnum): host = Host('dc1.{0}'.format(i), SimpleConvictionPolicy, host_id=uuid.uuid4()) - host.set_location_info('dc1', "rack1") + host = _with_location(host, 'dc1', "rack1") for vnode_num in range(vnodes_per_host): md5_token = MD5Token(current_token+vnode_num) token_to_host_owner[md5_token] = host @@ -269,10 +272,10 @@ def test_nts_make_token_replica_map_multi_rack(self): dc1_2 = Host('dc1.2', SimpleConvictionPolicy, host_id=uuid.uuid4()) dc1_3 = Host('dc1.3', SimpleConvictionPolicy, host_id=uuid.uuid4()) dc1_4 = Host('dc1.4', SimpleConvictionPolicy, host_id=uuid.uuid4()) - dc1_1.set_location_info('dc1', 'rack1') - dc1_2.set_location_info('dc1', 'rack1') - dc1_3.set_location_info('dc1', 'rack2') - dc1_4.set_location_info('dc1', 'rack2') + dc1_1 = _with_location(dc1_1, 'dc1', 'rack1') + dc1_2 = _with_location(dc1_2, 'dc1', 'rack1') + dc1_3 = _with_location(dc1_3, 'dc1', 'rack2') + dc1_4 = _with_location(dc1_4, 'dc1', 'rack2') token_to_host_owner[MD5Token(0)] = dc1_1 token_to_host_owner[MD5Token(100)] = dc1_2 token_to_host_owner[MD5Token(200)] = dc1_3 @@ -282,9 +285,9 @@ def test_nts_make_token_replica_map_multi_rack(self): dc2_1 = Host('dc2.1', SimpleConvictionPolicy, host_id=uuid.uuid4()) dc2_2 = Host('dc2.2', SimpleConvictionPolicy, host_id=uuid.uuid4()) dc2_3 = Host('dc2.3', SimpleConvictionPolicy, host_id=uuid.uuid4()) - dc2_1.set_location_info('dc2', 'rack1') - dc2_2.set_location_info('dc2', 'rack1') - dc2_3.set_location_info('dc2', 'rack2') + dc2_1 = _with_location(dc2_1, 'dc2', 'rack1') + dc2_2 = _with_location(dc2_2, 'dc2', 'rack1') + dc2_3 = _with_location(dc2_3, 'dc2', 'rack2') token_to_host_owner[MD5Token(1)] = dc2_1 token_to_host_owner[MD5Token(101)] = dc2_2 token_to_host_owner[MD5Token(201)] = dc2_3 @@ -305,7 +308,7 @@ def test_nts_make_token_replica_map_multi_rack(self): def test_nts_make_token_replica_map_empty_dc(self): host = Host('1', SimpleConvictionPolicy, host_id=uuid.uuid4()) - host.set_location_info('dc1', 'rack1') + host = _with_location(host, 'dc1', 'rack1') token_to_host_owner = {MD5Token(0): host} ring = [MD5Token(0)] nts = NetworkTopologyStrategy({'dc1': 1, 'dc2': 0}) diff --git a/tests/unit/test_policies.py b/tests/unit/test_policies.py index 6142af1aa1..23baa21e4c 100644 --- a/tests/unit/test_policies.py +++ b/tests/unit/test_policies.py @@ -40,6 +40,10 @@ from cassandra.tablets import Tablets, Tablet +def _with_location(host, datacenter, rack): + return host.set_location_info(datacenter, rack) + + class LoadBalancingPolicyTest(unittest.TestCase): def test_non_implemented(self): """ @@ -48,7 +52,7 @@ def test_non_implemented(self): policy = LoadBalancingPolicy() host = Host(DefaultEndPoint("ip1"), SimpleConvictionPolicy, host_id=uuid.uuid4()) - host.set_location_info("dc1", "rack1") + host = _with_location(host, "dc1", "rack1") with pytest.raises(NotImplementedError): policy.distance(host) @@ -194,11 +198,11 @@ def test_no_remote(self, policy_specialization, constructor_args): hosts = [] for i in range(2): h = Host(DefaultEndPoint(i), SimpleConvictionPolicy, host_id=uuid.uuid4()) - h.set_location_info("dc1", "rack2") + h = _with_location(h, "dc1", "rack2") hosts.append(h) for i in range(2): h = Host(DefaultEndPoint(i + 2), SimpleConvictionPolicy, host_id=uuid.uuid4()) - h.set_location_info("dc1", "rack1") + h = _with_location(h, "dc1", "rack1") hosts.append(h) random.shuffle(hosts) @@ -210,12 +214,9 @@ def test_no_remote(self, policy_specialization, constructor_args): def test_with_remotes(self, policy_specialization, constructor_args): hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(6)] - for h in hosts[:2]: - h.set_location_info("dc1", "rack1") - for h in hosts[2:4]: - h.set_location_info("dc1", "rack2") - for h in hosts[4:]: - h.set_location_info("dc2", "rack1") + hosts[:2] = [_with_location(h, "dc1", "rack1") for h in hosts[:2]] + hosts[2:4] = [_with_location(h, "dc1", "rack2") for h in hosts[2:4]] + hosts[4:] = [_with_location(h, "dc2", "rack1") for h in hosts[4:]] random.shuffle(hosts) @@ -265,7 +266,7 @@ def test_get_distance(self, policy_specialization, constructor_args): # same dc, same rack host = Host(DefaultEndPoint("ip1"), SimpleConvictionPolicy, host_id=uuid.uuid4()) - host.set_location_info("dc1", "rack1") + host = _with_location(host, "dc1", "rack1") policy.populate(Mock(), [host]) if isinstance(policy_specialization, DCAwareRoundRobinPolicy): @@ -275,14 +276,14 @@ def test_get_distance(self, policy_specialization, constructor_args): # same dc different rack host = Host(DefaultEndPoint("ip1"), SimpleConvictionPolicy, host_id=uuid.uuid4()) - host.set_location_info("dc1", "rack2") + host = _with_location(host, "dc1", "rack2") policy.populate(Mock(), [host]) assert policy.distance(host) == HostDistance.LOCAL # used_hosts_per_remote_dc is set to 0, so ignore it remote_host = Host(DefaultEndPoint("ip2"), SimpleConvictionPolicy, host_id=uuid.uuid4()) - remote_host.set_location_info("dc2", "rack1") + remote_host = _with_location(remote_host, "dc2", "rack1") assert policy.distance(remote_host) == HostDistance.IGNORED # dc2 isn't registered in the policy's live_hosts dict @@ -296,19 +297,16 @@ def test_get_distance(self, policy_specialization, constructor_args): # since used_hosts_per_remote_dc is set to 1, only the first # remote host in dc2 will be REMOTE, the rest are IGNORED second_remote_host = Host(DefaultEndPoint("ip3"), SimpleConvictionPolicy, host_id=uuid.uuid4()) - second_remote_host.set_location_info("dc2", "rack1") + second_remote_host = _with_location(second_remote_host, "dc2", "rack1") policy.populate(Mock(), [host, remote_host, second_remote_host]) distances = set([policy.distance(remote_host), policy.distance(second_remote_host)]) assert distances == set([HostDistance.REMOTE, HostDistance.IGNORED]) def test_status_updates(self, policy_specialization, constructor_args): hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(5)] - for h in hosts[:2]: - h.set_location_info("dc1", "rack1") - for h in hosts[2:4]: - h.set_location_info("dc1", "rack2") - for h in hosts[4:]: - h.set_location_info("dc2", "rack1") + hosts[:2] = [_with_location(h, "dc1", "rack1") for h in hosts[:2]] + hosts[2:4] = [_with_location(h, "dc1", "rack2") for h in hosts[2:4]] + hosts[4:] = [_with_location(h, "dc2", "rack1") for h in hosts[4:]] policy = policy_specialization(*constructor_args, used_hosts_per_remote_dc=1) policy.populate(Mock(), hosts) @@ -316,11 +314,11 @@ def test_status_updates(self, policy_specialization, constructor_args): policy.on_remove(hosts[2]) new_local_host = Host(DefaultEndPoint(5), SimpleConvictionPolicy, host_id=uuid.uuid4()) - new_local_host.set_location_info("dc1", "rack1") + new_local_host = _with_location(new_local_host, "dc1", "rack1") policy.on_up(new_local_host) new_remote_host = Host(DefaultEndPoint(6), SimpleConvictionPolicy, host_id=uuid.uuid4()) - new_remote_host.set_location_info("dc9000", "rack1") + new_remote_host = _with_location(new_remote_host, "dc9000", "rack1") policy.on_add(new_remote_host) # we now have three local hosts and two remote hosts in separate dcs @@ -345,10 +343,8 @@ def test_status_updates(self, policy_specialization, constructor_args): def test_modification_during_generation(self, policy_specialization, constructor_args): hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(4)] - for h in hosts[:2]: - h.set_location_info("dc1", "rack1") - for h in hosts[2:]: - h.set_location_info("dc2", "rack1") + hosts[:2] = [_with_location(h, "dc1", "rack1") for h in hosts[:2]] + hosts[2:] = [_with_location(h, "dc2", "rack1") for h in hosts[2:]] policy = policy_specialization(*constructor_args, used_hosts_per_remote_dc=3) policy.populate(Mock(), hosts) @@ -359,7 +355,7 @@ def test_modification_during_generation(self, policy_specialization, constructor # generator. new_host = Host(DefaultEndPoint(4), SimpleConvictionPolicy, host_id=uuid.uuid4()) - new_host.set_location_info("dc1", "rack1") + new_host = _with_location(new_host, "dc1", "rack1") # new local before iteration plan = policy.make_query_plan() @@ -389,7 +385,7 @@ def test_modification_during_generation(self, policy_specialization, constructor assert len(list(plan)) == 0 + 2 # REMOTES CHANGE - new_host.set_location_info("dc2", "rack1") + new_host = _with_location(new_host, "dc2", "rack1") # new remote after traversing local, but not starting remote plan = policy.make_query_plan() @@ -470,8 +466,8 @@ def test_modification_during_generation(self, policy_specialization, constructor policy.on_up(hosts[3]) another_host = Host(DefaultEndPoint(5), SimpleConvictionPolicy, host_id=uuid.uuid4()) - another_host.set_location_info("dc3", "rack1") - new_host.set_location_info("dc3", "rack1") + another_host = _with_location(another_host, "dc3", "rack1") + new_host = _with_location(new_host, "dc3", "rack1") # new DC while traversing remote plan = policy.make_query_plan() @@ -504,7 +500,7 @@ def test_no_live_nodes(self, policy_specialization, constructor_args): hosts = [] for i in range(4): h = Host(DefaultEndPoint(i), SimpleConvictionPolicy, host_id=uuid.uuid4()) - h.set_location_info("dc1", "rack1") + h = _with_location(h, "dc1", "rack1") hosts.append(h) policy = policy_specialization(*constructor_args, used_hosts_per_remote_dc=1) @@ -529,8 +525,7 @@ def test_no_nodes(self, policy_specialization, constructor_args): def test_wrong_dc(self, policy_specialization, constructor_args): hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(3)] - for h in hosts[:3]: - h.set_location_info("dc2", "rack2") + hosts[:3] = [_with_location(h, "dc2", "rack2") for h in hosts[:3]] policy = policy_specialization(*constructor_args, used_hosts_per_remote_dc=0) policy.populate(Mock(), hosts) @@ -613,10 +608,8 @@ def test_wrap_dc_aware(self): hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(4)] for host in hosts: host.set_up() - for h in hosts[:2]: - h.set_location_info("dc1", "rack1") - for h in hosts[2:]: - h.set_location_info("dc2", "rack1") + hosts[:2] = [_with_location(h, "dc1", "rack1") for h in hosts[:2]] + hosts[2:] = [_with_location(h, "dc2", "rack1") for h in hosts[2:]] def get_replicas(keyspace, packed_key): index = struct.unpack('>i', packed_key)[0] @@ -662,14 +655,14 @@ def test_wrap_rack_aware(self): hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(8)] for host in hosts: host.set_up() - hosts[0].set_location_info("dc1", "rack1") - hosts[1].set_location_info("dc1", "rack2") - hosts[2].set_location_info("dc2", "rack1") - hosts[3].set_location_info("dc2", "rack2") - hosts[4].set_location_info("dc1", "rack1") - hosts[5].set_location_info("dc1", "rack2") - hosts[6].set_location_info("dc2", "rack1") - hosts[7].set_location_info("dc2", "rack2") + hosts[0] = _with_location(hosts[0], "dc1", "rack1") + hosts[1] = _with_location(hosts[1], "dc1", "rack2") + hosts[2] = _with_location(hosts[2], "dc2", "rack1") + hosts[3] = _with_location(hosts[3], "dc2", "rack2") + hosts[4] = _with_location(hosts[4], "dc1", "rack1") + hosts[5] = _with_location(hosts[5], "dc1", "rack2") + hosts[6] = _with_location(hosts[6], "dc2", "rack1") + hosts[7] = _with_location(hosts[7], "dc2", "rack2") def get_replicas(keyspace, packed_key): index = struct.unpack('>i', packed_key)[0] @@ -724,7 +717,7 @@ def test_get_distance(self): policy = TokenAwarePolicy(DCAwareRoundRobinPolicy("dc1", used_hosts_per_remote_dc=0)) host = Host(DefaultEndPoint("ip1"), SimpleConvictionPolicy, host_id=uuid.uuid4()) - host.set_location_info("dc1", "rack1") + host = _with_location(host, "dc1", "rack1") policy.populate(self.FakeCluster(), [host]) @@ -732,7 +725,7 @@ def test_get_distance(self): # used_hosts_per_remote_dc is set to 0, so ignore it remote_host = Host(DefaultEndPoint("ip2"), SimpleConvictionPolicy, host_id=uuid.uuid4()) - remote_host.set_location_info("dc2", "rack1") + remote_host = _with_location(remote_host, "dc2", "rack1") assert policy.distance(remote_host) == HostDistance.IGNORED # dc2 isn't registered in the policy's live_hosts dict @@ -746,7 +739,7 @@ def test_get_distance(self): # since used_hosts_per_remote_dc is set to 1, only the first # remote host in dc2 will be REMOTE, the rest are IGNORED second_remote_host = Host(DefaultEndPoint("ip3"), SimpleConvictionPolicy, host_id=uuid.uuid4()) - second_remote_host.set_location_info("dc2", "rack1") + second_remote_host = _with_location(second_remote_host, "dc2", "rack1") policy.populate(self.FakeCluster(), [host, remote_host, second_remote_host]) distances = set([policy.distance(remote_host), policy.distance(second_remote_host)]) assert distances == set([HostDistance.REMOTE, HostDistance.IGNORED]) @@ -757,10 +750,8 @@ def test_status_updates(self): """ hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy, host_id=uuid.uuid4()) for i in range(4)] - for h in hosts[:2]: - h.set_location_info("dc1", "rack1") - for h in hosts[2:]: - h.set_location_info("dc2", "rack1") + hosts[:2] = [_with_location(h, "dc1", "rack1") for h in hosts[:2]] + hosts[2:] = [_with_location(h, "dc2", "rack1") for h in hosts[2:]] policy = TokenAwarePolicy(DCAwareRoundRobinPolicy("dc1", used_hosts_per_remote_dc=1)) policy.populate(self.FakeCluster(), hosts) @@ -768,11 +759,11 @@ def test_status_updates(self): policy.on_remove(hosts[2]) new_local_host = Host(DefaultEndPoint(4), SimpleConvictionPolicy, host_id=uuid.uuid4()) - new_local_host.set_location_info("dc1", "rack1") + new_local_host = _with_location(new_local_host, "dc1", "rack1") policy.on_up(new_local_host) new_remote_host = Host(DefaultEndPoint(5), SimpleConvictionPolicy, host_id=uuid.uuid4()) - new_remote_host.set_location_info("dc9000", "rack1") + new_remote_host = _with_location(new_remote_host, "dc9000", "rack1") policy.on_add(new_remote_host) # we now have two local hosts and two remote hosts in separate dcs @@ -1647,9 +1638,8 @@ def get_replicas(keyspace, packed_key): query_plan = hfp.make_query_plan("keyspace", mocked_query) # First the not filtered replica, and then the rest of the allowed hosts ordered query_plan = list(query_plan) - assert query_plan[0] == Host(DefaultEndPoint("127.0.0.2"), SimpleConvictionPolicy, host_id=uuid.uuid4()) - assert set(query_plan[1:]) == {Host(DefaultEndPoint("127.0.0.3"), SimpleConvictionPolicy, host_id=uuid.uuid4()), - Host(DefaultEndPoint("127.0.0.5"), SimpleConvictionPolicy, host_id=uuid.uuid4())} + assert query_plan[0].address == "127.0.0.2" + assert {host.address for host in query_plan[1:]} == {"127.0.0.3", "127.0.0.5"} def test_create_whitelist(self): cluster = Mock(spec=Cluster) @@ -1671,5 +1661,4 @@ def test_create_whitelist(self): mocked_query = Mock() query_plan = hfp.make_query_plan("keyspace", mocked_query) # Only the filtered replicas should be allowed - assert set(query_plan) == {Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=uuid.uuid4()), - Host(DefaultEndPoint("127.0.0.4"), SimpleConvictionPolicy, host_id=uuid.uuid4())} + assert {host.address for host in query_plan} == {"127.0.0.1", "127.0.0.4"} diff --git a/tests/unit/test_types.py b/tests/unit/test_types.py index 11aab2748d..62aa01a78e 100644 --- a/tests/unit/test_types.py +++ b/tests/unit/test_types.py @@ -1020,12 +1020,13 @@ def test_host_order(self): """ hosts = [Host(addr, SimpleConvictionPolicy, host_id=uuid.uuid4()) for addr in ("127.0.0.1", "127.0.0.2", "127.0.0.3", "127.0.0.4")] - hosts_equal = [Host(addr, SimpleConvictionPolicy, host_id=uuid.uuid4()) for addr in - ("127.0.0.1", "127.0.0.1")] - hosts_equal_conviction = [Host("127.0.0.1", SimpleConvictionPolicy, host_id=uuid.uuid4()), Host("127.0.0.1", ConvictionPolicy, host_id=uuid.uuid4())] + hosts_equal = [Host("127.0.0.1", SimpleConvictionPolicy, host_id="a"), + Host("127.0.0.1", SimpleConvictionPolicy, host_id="b")] + hosts_equal_conviction = [Host("127.0.0.1", SimpleConvictionPolicy, host_id="a"), + Host("127.0.0.1", ConvictionPolicy, host_id="b")] check_sequence_consistency(hosts) - check_sequence_consistency(hosts_equal, equal=True) - check_sequence_consistency(hosts_equal_conviction, equal=True) + check_sequence_consistency(hosts_equal) + check_sequence_consistency(hosts_equal_conviction) def test_date_order(self): """