Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
246 changes: 186 additions & 60 deletions cassandra/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -1870,6 +1870,26 @@ def _cleanup_failed_on_up_handling(self, host):

self._start_reconnector(host, is_host_addition=False)

def _cleanup_failed_on_add_handling(self, host, addition_generation=None):
    """
    Roll back a failed host addition and schedule a reconnection attempt.

    The host is marked down, the profile manager and control connection
    are notified, the host's pool is removed from every session, the
    "addition in progress" markers on the host are cleared, and a
    reconnector is started with ``is_host_addition=True``.

    :param host: the host whose addition failed.
    :param addition_generation: token object identifying the ``on_add``
        attempt that requested the cleanup. When given and the host's
        ``_node_addition_generation`` is a different object, the call is
        treated as stale and returns without touching anything. ``None``
        means "clean up unconditionally".
    """
    with host.lock:
        if (addition_generation is not None and
                getattr(host, "_node_addition_generation", None) is not addition_generation):
            # A different addition attempt owns this host now; leave it alone.
            return
        host.set_down()

    try:
        self.profile_manager.on_down(host)
        self.control_connection.on_down(host)
        for session in tuple(self.sessions):
            session.remove_pool(host)
    finally:
        # Always release the in-progress markers, even when one of the
        # notifications above raises. Otherwise the host would stay flagged
        # as "being added" and later on_up()/on_add() calls would return
        # early for it.
        with host.lock:
            if (addition_generation is None or
                    getattr(host, "_node_addition_generation", None) is addition_generation):
                host._currently_handling_node_addition = False
                host._node_addition_generation = None

    self._start_reconnector(host, is_host_addition=True)

def _on_up_future_completed(self, host, futures, results, lock, finished_future):
with lock:
futures.discard(finished_future)
Expand Down Expand Up @@ -1916,8 +1936,9 @@ def on_up(self, host):

log.debug("Waiting to acquire lock for handling up status of node %s", host)
with host.lock:
if host._currently_handling_node_up:
log.debug("Another thread is already handling up status of node %s", host)
if (host._currently_handling_node_up or
getattr(host, "_currently_handling_node_addition", False)):
log.debug("Another thread is already handling up/add status of node %s", host)
return

if host.is_up:
Expand Down Expand Up @@ -1958,8 +1979,10 @@ def on_up(self, host):
future = session.add_or_renew_pool(host, is_host_addition=False)
if future is not None:
have_future = True
future.add_done_callback(callback)
futures.add(future)

for future in tuple(futures):
future.add_done_callback(callback)
except Exception:
log.exception("Unexpected failure handling node %s being marked up:", host)
for future in futures:
Expand Down Expand Up @@ -2050,69 +2073,102 @@ def on_add(self, host, refresh_nodes=True):

log.debug("Handling new host %r and notifying listeners", host)

self.profile_manager.on_add(host)
self.control_connection.on_add(host, refresh_nodes)

distance = self.profile_manager.distance(host)
if distance != HostDistance.IGNORED:
self._prepare_all_queries(host)
log.debug("Done preparing queries for new host %r", host)

if distance == HostDistance.IGNORED:
log.debug("Not adding connection pool for new host %r because the "
"load balancing policy has marked it as IGNORED", host)
self._finalize_add(host, set_up=False)
return
# Keep refresh-time pool rebuilds from racing this host's pool creation.
addition_generation = object()
with host.lock:
if getattr(host, "_currently_handling_node_addition", False):
log.debug("Another thread is already handling add status of node %s", host)
return
host._currently_handling_node_addition = True
host._node_addition_generation = addition_generation

have_future = False
add_aborted = False
futures = {}
futures_lock = Lock()
futures_results = []
futures = set()
finalize_add = None
try:
self.profile_manager.on_add(host)
self.control_connection.on_add(host, refresh_nodes)

def future_completed(future):
with futures_lock:
futures.discard(future)
distance = self.profile_manager.distance(host)
if distance != HostDistance.IGNORED:
self._prepare_all_queries(host)
log.debug("Done preparing queries for new host %r", host)

try:
futures_results.append(future.result())
except Exception as exc:
futures_results.append(exc)
if distance == HostDistance.IGNORED:
log.debug("Not adding connection pool for new host %r because the "
"load balancing policy has marked it as IGNORED", host)
finalize_add = False
else:
def future_completed(future):
with futures_lock:
futures.pop(future, None)

if futures:
return
if add_aborted:
return

log.debug('All futures have completed for added host %s', host)
try:
futures_results.append(future.result())
except Exception as exc:
futures_results.append(exc)

for exc in [f for f in futures_results if isinstance(f, Exception)]:
log.error("Unexpected failure while adding node %s, will not mark up:", host, exc_info=exc)
return
if futures:
return

if not all(futures_results):
log.warning("Connection pool could not be created, not marking node %s up", host)
return
log.debug('All futures have completed for added host %s', host)

self._finalize_add(host)
for exc in [f for f in futures_results if isinstance(f, Exception)]:
log.error("Unexpected failure while adding node %s, will not mark up:", host, exc_info=exc)
self._cleanup_failed_on_add_handling(host, addition_generation)
return

have_future = False
for session in tuple(self.sessions):
future = session.add_or_renew_pool(host, is_host_addition=True)
if future is not None:
have_future = True
futures.add(future)
future.add_done_callback(future_completed)
if not all(futures_results):
log.warning("Connection pool could not be created, not marking node %s up", host)
self._cleanup_failed_on_add_handling(host, addition_generation)
return

if not have_future:
self._finalize_add(host)
self._finalize_add(host, addition_generation=addition_generation)

def _finalize_add(self, host, set_up=True):
if set_up:
host.set_up()
for session in tuple(self.sessions):
future = session.add_or_renew_pool(host, is_host_addition=True)
if future is not None:
have_future = True
futures[future] = session

for listener in self.listeners:
listener.on_add(host)
for future in tuple(futures):
future.add_done_callback(future_completed)

# see if there are any pools to add or remove now that the host is marked up
for session in tuple(self.sessions):
session.update_created_pools()
if not have_future:
finalize_add = True
except Exception:
add_aborted = True
for future, session in tuple(futures.items()):
future.cancel()
self._cleanup_failed_on_add_handling(host, addition_generation)
raise

if finalize_add is not None:
self._finalize_add(host, set_up=finalize_add, addition_generation=addition_generation)

def _finalize_add(self, host, set_up=True, addition_generation=None):
    """
    Complete the addition of ``host`` and release its in-progress markers.

    Optionally marks the host up, notifies every registered listener via
    ``on_add`` and asks each session to re-evaluate its pools. The markers
    are released in a ``finally`` clause, so they are cleared even when a
    listener or session raises.

    :param host: the host being added.
    :param set_up: when true, ``host.set_up()`` is called first.
    :param addition_generation: token object of the ``on_add`` attempt that
        is finishing. The markers are cleared only when this is ``None`` or
        is the same object as the host's ``_node_addition_generation``.
    """
    try:
        if set_up:
            host.set_up()

        for listener in self.listeners:
            listener.on_add(host)

        # see if there are any pools to add or remove now that the host is marked up
        # NOTE(review): the in-progress markers are still set at this point;
        # confirm update_created_pools() is meant to run before they clear.
        session_snapshot = tuple(self.sessions)
        for session in session_snapshot:
            session.update_created_pools()
    finally:
        with host.lock:
            superseded = (
                addition_generation is not None and
                getattr(host, "_node_addition_generation", None) is not addition_generation
            )
            if not superseded:
                host._currently_handling_node_addition = False
                host._node_addition_generation = None

def on_remove(self, host):
if self.is_shutdown:
Expand All @@ -2137,7 +2193,8 @@ def signal_connection_failure(self, host, connection_exc, is_host_addition, expe
self.on_down(host, is_host_addition, expect_host_to_be_down)
return is_down

def add_host(self, endpoint, datacenter=None, rack=None, signal=True, refresh_nodes=True, host_id=None):
def add_host(self, endpoint, datacenter=None, rack=None, signal=True, refresh_nodes=True, host_id=None,
is_zero_token=None):
"""
Called when adding initial contact points and when the control
connection subsequently discovers a new node.
Expand All @@ -2147,8 +2204,16 @@ def add_host(self, endpoint, datacenter=None, rack=None, signal=True, refresh_no
"""
with self.metadata._hosts_lock:
if endpoint in self.metadata._host_id_by_endpoint:
return self.metadata._hosts[self.metadata._host_id_by_endpoint[endpoint]], False
host, new = self.metadata.add_or_return_host(Host(endpoint, self.conviction_policy_factory, datacenter, rack, host_id=host_id))
host = self.metadata._hosts[self.metadata._host_id_by_endpoint[endpoint]]
if is_zero_token is not None:
host.is_zero_token = is_zero_token
return host, False
host = Host(endpoint, self.conviction_policy_factory, datacenter, rack, host_id=host_id)
if is_zero_token is not None:
host.is_zero_token = is_zero_token
host, new = self.metadata.add_or_return_host(host)
if not new and is_zero_token is not None:
host.is_zero_token = is_zero_token
if new and signal:
log.info("New Cassandra host %r discovered", host)
self.on_add(host, refresh_nodes)
Expand Down Expand Up @@ -3239,15 +3304,27 @@ def add_or_renew_pool(self, host, is_host_addition):
distance = self._profile_manager.distance(host)
if distance == HostDistance.IGNORED:
return None
addition_generation = getattr(host, "_node_addition_generation", None) if is_host_addition else None

def is_stale_addition():
return (is_host_addition and
getattr(host, "_node_addition_generation", None) is not addition_generation)

def run_add_or_renew_pool():
if is_stale_addition():
return False

try:
new_pool = HostConnection(host, distance, self)
new_pool = HostConnection(host, distance, self)
except AuthenticationFailed as auth_exc:
if is_stale_addition():
return False
conn_exc = ConnectionException(str(auth_exc), endpoint=host)
self.cluster.signal_connection_failure(host, conn_exc, is_host_addition)
return False
except Exception as conn_exc:
if is_stale_addition():
return False
log.warning("Failed to create connection pool for new host %s:",
host, exc_info=conn_exc)
# the host itself will still be marked down, so we need to pass
Expand All @@ -3256,6 +3333,10 @@ def run_add_or_renew_pool():
host, conn_exc, is_host_addition, expect_host_to_be_down=True)
return False

if is_stale_addition():
new_pool.shutdown()
return False

previous = self._pools.get(host)
with self._lock:
while new_pool._keyspace != self.keyspace:
Expand All @@ -3270,12 +3351,19 @@ def callback(pool, errors):
new_pool._set_keyspace_for_all_conns(self.keyspace, callback)
set_keyspace_event.wait(self.cluster.connect_timeout)
if not set_keyspace_event.is_set() or errors_returned:
if is_stale_addition():
new_pool.shutdown()
self._lock.acquire()
return False
log.warning("Failed setting keyspace for pool after keyspace changed during connect: %s", errors_returned)
self.cluster.on_down(host, is_host_addition)
new_pool.shutdown()
self._lock.acquire()
return False
self._lock.acquire()
if is_stale_addition():
new_pool.shutdown()
return False
self._pools[host] = new_pool

log.debug("Added pool for host %s to session", host)
Expand Down Expand Up @@ -3315,7 +3403,10 @@ def update_created_pools(self):
# we don't eagerly set is_up on previously ignored hosts. None is included here
# to allow us to attempt connections to hosts that have gone from ignored to something
# else.
if distance != HostDistance.IGNORED and host.is_up in (True, None):
# on_up() and on_add() already own pool creation for hosts in flight.
if (distance != HostDistance.IGNORED and host.is_up in (True, None) and
not host._currently_handling_node_up and
not getattr(host, "_currently_handling_node_addition", False)):
future = self.add_or_renew_pool(host, False)
elif distance != pool.host_distance:
# the distance has changed
Expand Down Expand Up @@ -3864,6 +3955,8 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None,
# every node in the cluster was in the contact points, we won't discover
# any new nodes, so we need this additional check. (See PYTHON-90)
should_rebuild_token_map = force_token_rebuild or self._cluster.metadata.partitioner is None
zero_token_status_changed = False
promoted_zero_token_hosts = []
for row in peers_result:
if not self._is_valid_peer(row):
continue
Expand All @@ -3884,10 +3977,20 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None,
host = self._cluster.metadata.get_host(endpoint)
datacenter = row.get("data_center")
rack = row.get("rack")
tokens = row.get("tokens", None)
has_token_status = "tokens" in row
is_zero_token = has_token_status and not tokens
token_status = is_zero_token if has_token_status else None

if host is None:
host = self._cluster.metadata.get_host_by_host_id(host_id)
if host and host.endpoint != endpoint:
if has_token_status:
status_changed = self._update_zero_token_info(host, is_zero_token)
zero_token_status_changed |= status_changed
should_rebuild_token_map |= status_changed
if status_changed and not is_zero_token and host.is_up is not True:
promoted_zero_token_hosts.append(host)
log.debug("[control connection] Updating host ip from %s to %s for (%s)", host.endpoint, endpoint, host_id)
reconnector = host.get_and_set_reconnection_handler(None)
if reconnector:
Expand All @@ -3901,11 +4004,20 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None,

if host is None:
log.debug("[control connection] Found new host to connect to: %s", endpoint)
host, _ = self._cluster.add_host(endpoint, datacenter=datacenter, rack=rack, signal=True, refresh_nodes=False, host_id=host_id)
host, _ = self._cluster.add_host(endpoint, datacenter=datacenter, rack=rack, signal=True,
refresh_nodes=False, host_id=host_id,
is_zero_token=token_status)
should_rebuild_token_map = True
else:
should_rebuild_token_map |= self._update_location_info(host, datacenter, rack)

if has_token_status:
status_changed = self._update_zero_token_info(host, is_zero_token)
zero_token_status_changed |= status_changed
should_rebuild_token_map |= status_changed
if status_changed and not is_zero_token and host.is_up is not True:
promoted_zero_token_hosts.append(host)

host.host_id = host_id
host.broadcast_address = _NodeInfo.get_broadcast_address(row)
host.broadcast_port = _NodeInfo.get_broadcast_port(row)
Expand All @@ -3916,7 +4028,6 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None,
host.dse_workload = row.get("workload")
host.dse_workloads = row.get("workloads")

tokens = row.get("tokens", None)
if partitioner and tokens and self._token_meta_enabled:
token_map[host] = tokens
self._cluster.metadata.update_host(host, old_endpoint=endpoint)
Expand All @@ -3932,6 +4043,22 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None,
log.debug("[control connection] Rebuilding token map due to topology changes")
self._cluster.metadata.rebuild_token_map(partitioner, token_map)

for host in promoted_zero_token_hosts:
self._cluster.on_up(host)

if zero_token_status_changed:
for session in tuple(getattr(self._cluster, "sessions", ())):
session.update_created_pools()

@staticmethod
def _update_zero_token_info(host, is_zero_token):
    """
    Record on ``host`` whether it is a zero-token node.

    ``is_zero_token`` is coerced to ``bool`` before it is compared with and
    stored in ``host.is_zero_token``.

    :returns: ``True`` when the stored value was changed, ``False`` when the
        host already carried the same value.
    """
    normalized = bool(is_zero_token)
    if host.is_zero_token == normalized:
        changed = False
    else:
        host.is_zero_token = normalized
        changed = True
    return changed

@staticmethod
def _is_valid_peer(row):
broadcast_rpc = _NodeInfo.get_broadcast_rpc_address(row)
Expand Down Expand Up @@ -3963,9 +4090,8 @@ def _is_valid_peer(row):

if "tokens" in row and not row.get("tokens"):
log.debug(
"Found a zero-token node - tokens is None (broadcast_rpc: %s, host_id: %s). Ignoring host." %
"Found a zero-token node - tokens are empty (broadcast_rpc: %s, host_id: %s). Adding host without tokens." %
(broadcast_rpc, host_id))
return False

return True

Expand Down
Loading
Loading