Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
301 changes: 242 additions & 59 deletions cassandra/cluster.py

Large diffs are not rendered by default.

150 changes: 150 additions & 0 deletions cassandra/events.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Internal driver event primitives.

This module intentionally does not expose a public subscription API. It is a
small synchronous bus used to decouple driver subsystems that need to react to
shared internal state changes.
"""

from collections import defaultdict
import logging
from threading import RLock


log = logging.getLogger(__name__)


HOST = "HOST"

HOST_ADDED = "HOST_ADDED"
HOST_REMOVED = "HOST_REMOVED"
HOST_UP = "HOST_UP"
HOST_DOWN = "HOST_DOWN"
HOST_CHANGED = "HOST_CHANGED"


class DriverEvent(object):
"""
Internal event envelope.
"""

__slots__ = ("type", "category", "payload", "source")

def __init__(self, event_type, category, payload=None, source=None):
self.type = event_type
self.category = category
self.payload = payload
self.source = source

def __repr__(self):
return "%s(type=%r, category=%r, payload=%r, source=%r)" % (
self.__class__.__name__, self.type, self.category, self.payload, self.source)


class HostEventPayload(object):
"""
Payload for host topology and runtime-state events.
"""

__slots__ = (
"host_id", "host", "old_host", "new_host", "changed_fields",
"old_values", "new_values", "refresh_nodes")

def __init__(self, host=None, host_id=None, old_host=None, new_host=None,
changed_fields=(), old_values=None, new_values=None,
refresh_nodes=True):
if host is None:
host = new_host if new_host is not None else old_host

if host_id is None and host is not None:
host_id = host.host_id

self.host_id = host_id
self.host = host
self.old_host = old_host
self.new_host = new_host
self.changed_fields = tuple(changed_fields or ())
self.old_values = old_values or {}
self.new_values = new_values or {}
self.refresh_nodes = refresh_nodes

def __repr__(self):
return ("%s(host_id=%r, host=%r, old_host=%r, new_host=%r, "
"changed_fields=%r, old_values=%r, new_values=%r)") % (
self.__class__.__name__, self.host_id, self.host,
self.old_host, self.new_host, self.changed_fields,
self.old_values, self.new_values)


class _EventBus(object):
"""
Synchronous internal event bus.
"""

def __init__(self):
self._type_subscribers = defaultdict(list)
self._category_subscribers = defaultdict(list)
self._lock = RLock()

def subscribe(self, event_type, handler):
with self._lock:
handlers = self._type_subscribers[event_type]
if handler not in handlers:
handlers.append(handler)

def unsubscribe(self, event_type, handler):
with self._lock:
self._remove_handler(self._type_subscribers.get(event_type), handler)

def subscribe_category(self, category, handler):
with self._lock:
handlers = self._category_subscribers[category]
if handler not in handlers:
handlers.append(handler)

def unsubscribe_category(self, category, handler):
with self._lock:
self._remove_handler(self._category_subscribers.get(category), handler)

def publish(self, event):
handlers = self._handlers_for_event(event)
for handler in handlers:
try:
handler(event)
except Exception:
log.exception("Error dispatching driver event %s to %r", event.type, handler)
return event

@staticmethod
def _remove_handler(handlers, handler):
if not handlers:
return
try:
handlers.remove(handler)
except ValueError:
pass

def _handlers_for_event(self, event):
with self._lock:
raw_handlers = list(self._type_subscribers.get(event.type, ()))
raw_handlers.extend(self._category_subscribers.get(event.category, ()))

handlers = []
for handler in raw_handlers:
if handler not in handlers:
handlers.append(handler)
return handlers
81 changes: 79 additions & 2 deletions cassandra/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from cassandra import SignatureDescriptor, ConsistencyLevel, InvalidRequest, Unauthorized
import cassandra.cqltypes as types
from cassandra.encoder import Encoder
from cassandra.events import DriverEvent, HOST, HOST_CHANGED, HostEventPayload
from cassandra.marshal import varint_unpack
from cassandra.protocol import QueryMessage
from cassandra.query import dict_factory, bind_params
Expand Down Expand Up @@ -121,14 +122,22 @@ class Metadata(object):
dbaas = False
""" A boolean indicating if connected to a DBaaS cluster """

def __init__(self):
def __init__(self, event_bus=None):
self.keyspaces = {}
self.dbaas = False
self._hosts = {}
self._host_id_by_endpoint = {}
self._runtime_states = {}
self._event_bus = event_bus
self._hosts_lock = RLock()
self._tablets = Tablets({})

def set_event_bus(self, event_bus):
self._event_bus = event_bus
with self._hosts_lock:
for runtime_state in self._runtime_states.values():
runtime_state.set_event_bus(event_bus)

def export_schema_as_string(self):
"""
Returns a string that can be executed as a query in order to recreate
Expand Down Expand Up @@ -340,21 +349,30 @@ def add_or_return_host(self, host):
try:
return self._hosts[host.host_id], False
except KeyError:
host = self._bind_runtime_state(host)
self._host_id_by_endpoint[host.endpoint] = host.host_id
self._hosts[host.host_id] = host
return host, True

def remove_host(self, host):
self._tablets.drop_tablets_by_host_id(host.host_id)
with self._hosts_lock:
current_host = self._hosts.get(host.host_id)
self._host_id_by_endpoint.pop(host.endpoint, False)
if current_host is not None:
self._host_id_by_endpoint.pop(current_host.endpoint, False)
self._runtime_states.pop(host.host_id, None)
return bool(self._hosts.pop(host.host_id, False))

def remove_host_by_host_id(self, host_id, endpoint=None):
self._tablets.drop_tablets_by_host_id(host_id)
with self._hosts_lock:
if endpoint and self._host_id_by_endpoint[endpoint] == host_id:
current_host = self._hosts.get(host_id)
if endpoint and self._host_id_by_endpoint.get(endpoint) == host_id:
self._host_id_by_endpoint.pop(endpoint, False)
if current_host is not None:
self._host_id_by_endpoint.pop(current_host.endpoint, False)
self._runtime_states.pop(host_id, None)
return bool(self._hosts.pop(host_id, False))

def update_host(self, host, old_endpoint):
Expand All @@ -363,6 +381,65 @@ def update_host(self, host, old_endpoint):
self._host_id_by_endpoint.pop(old_endpoint, False)
self._host_id_by_endpoint[host.endpoint] = host.host_id

def replace_host(self, host_id, source=None, **fields):
"""
Replace a Host topology snapshot for host_id and publish HOST_CHANGED.
"""
with self._hosts_lock:
old_host = self._hosts.get(host_id)
if old_host is None:
return None, ()

changed_fields = []
old_values = {}
new_values = {}

for field, new_value in fields.items():
old_value = getattr(old_host, field)
if old_value != new_value:
changed_fields.append(field)
old_values[field] = old_value
new_values[field] = new_value

if not changed_fields:
return old_host, ()

copy_kwargs = dict((field, fields[field]) for field in changed_fields)
new_host = old_host.copy_with(**copy_kwargs)
new_host.runtime_state.set_event_bus(self._event_bus)
new_host.runtime_state.bind_host(new_host)

self._hosts[host_id] = new_host
if "endpoint" in changed_fields:
if self._host_id_by_endpoint.get(old_host.endpoint) == host_id:
self._host_id_by_endpoint.pop(old_host.endpoint, False)
self._host_id_by_endpoint[new_host.endpoint] = host_id
else:
self._host_id_by_endpoint[new_host.endpoint] = host_id

payload = HostEventPayload(
host_id=host_id,
old_host=old_host,
new_host=new_host,
changed_fields=tuple(changed_fields),
old_values=old_values,
new_values=new_values)
if self._event_bus:
self._event_bus.publish(DriverEvent(HOST_CHANGED, HOST, payload, source or self))
return new_host, tuple(changed_fields)

def _bind_runtime_state(self, host):
runtime_state = self._runtime_states.get(host.host_id)
if runtime_state is None:
runtime_state = host.runtime_state
self._runtime_states[host.host_id] = runtime_state
elif host.runtime_state is not runtime_state:
host = host.copy_with(runtime_state=runtime_state)

runtime_state.set_event_bus(self._event_bus)
runtime_state.bind_host(host)
return host

def get_host(self, endpoint_or_address, port=None):
"""
Find a host in the metadata for a specific endpoint. If a string inet address and port are passed,
Expand Down
Loading
Loading