Skip to content

Commit 66bb094

Browse files
committed
host-events: finalize Host event refactor
Finalize HOST_ADDED delivery after host setup, route topology replacements through HOST_CHANGED, and keep pools keyed by host_id. Add regression coverage for listener timing, immutable Host snapshots, and pool rebinding.
1 parent e6f9e9f commit 66bb094

11 files changed

Lines changed: 1343 additions & 210 deletions

cassandra/cluster.py

Lines changed: 242 additions & 59 deletions
Large diffs are not rendered by default.

cassandra/events.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
# Copyright DataStax, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""
16+
Internal driver event primitives.
17+
18+
This module intentionally does not expose a public subscription API. It is a
19+
small synchronous bus used to decouple driver subsystems that need to react to
20+
shared internal state changes.
21+
"""
22+
23+
from collections import defaultdict
24+
import logging
25+
from threading import RLock
26+
27+
28+
log = logging.getLogger(__name__)
29+
30+
31+
HOST = "HOST"
32+
33+
HOST_ADDED = "HOST_ADDED"
34+
HOST_REMOVED = "HOST_REMOVED"
35+
HOST_UP = "HOST_UP"
36+
HOST_DOWN = "HOST_DOWN"
37+
HOST_CHANGED = "HOST_CHANGED"
38+
39+
40+
class DriverEvent(object):
41+
"""
42+
Internal event envelope.
43+
"""
44+
45+
__slots__ = ("type", "category", "payload", "source")
46+
47+
def __init__(self, event_type, category, payload=None, source=None):
48+
self.type = event_type
49+
self.category = category
50+
self.payload = payload
51+
self.source = source
52+
53+
def __repr__(self):
54+
return "%s(type=%r, category=%r, payload=%r, source=%r)" % (
55+
self.__class__.__name__, self.type, self.category, self.payload, self.source)
56+
57+
58+
class HostEventPayload(object):
59+
"""
60+
Payload for host topology and runtime-state events.
61+
"""
62+
63+
__slots__ = (
64+
"host_id", "host", "old_host", "new_host", "changed_fields",
65+
"old_values", "new_values", "refresh_nodes")
66+
67+
def __init__(self, host=None, host_id=None, old_host=None, new_host=None,
68+
changed_fields=(), old_values=None, new_values=None,
69+
refresh_nodes=True):
70+
if host is None:
71+
host = new_host if new_host is not None else old_host
72+
73+
if host_id is None and host is not None:
74+
host_id = host.host_id
75+
76+
self.host_id = host_id
77+
self.host = host
78+
self.old_host = old_host
79+
self.new_host = new_host
80+
self.changed_fields = tuple(changed_fields or ())
81+
self.old_values = old_values or {}
82+
self.new_values = new_values or {}
83+
self.refresh_nodes = refresh_nodes
84+
85+
def __repr__(self):
86+
return ("%s(host_id=%r, host=%r, old_host=%r, new_host=%r, "
87+
"changed_fields=%r, old_values=%r, new_values=%r)") % (
88+
self.__class__.__name__, self.host_id, self.host,
89+
self.old_host, self.new_host, self.changed_fields,
90+
self.old_values, self.new_values)
91+
92+
93+
class _EventBus(object):
94+
"""
95+
Synchronous internal event bus.
96+
"""
97+
98+
def __init__(self):
99+
self._type_subscribers = defaultdict(list)
100+
self._category_subscribers = defaultdict(list)
101+
self._lock = RLock()
102+
103+
def subscribe(self, event_type, handler):
104+
with self._lock:
105+
handlers = self._type_subscribers[event_type]
106+
if handler not in handlers:
107+
handlers.append(handler)
108+
109+
def unsubscribe(self, event_type, handler):
110+
with self._lock:
111+
self._remove_handler(self._type_subscribers.get(event_type), handler)
112+
113+
def subscribe_category(self, category, handler):
114+
with self._lock:
115+
handlers = self._category_subscribers[category]
116+
if handler not in handlers:
117+
handlers.append(handler)
118+
119+
def unsubscribe_category(self, category, handler):
120+
with self._lock:
121+
self._remove_handler(self._category_subscribers.get(category), handler)
122+
123+
def publish(self, event):
124+
handlers = self._handlers_for_event(event)
125+
for handler in handlers:
126+
try:
127+
handler(event)
128+
except Exception:
129+
log.exception("Error dispatching driver event %s to %r", event.type, handler)
130+
return event
131+
132+
@staticmethod
133+
def _remove_handler(handlers, handler):
134+
if not handlers:
135+
return
136+
try:
137+
handlers.remove(handler)
138+
except ValueError:
139+
pass
140+
141+
def _handlers_for_event(self, event):
142+
with self._lock:
143+
raw_handlers = list(self._type_subscribers.get(event.type, ()))
144+
raw_handlers.extend(self._category_subscribers.get(event.category, ()))
145+
146+
handlers = []
147+
for handler in raw_handlers:
148+
if handler not in handlers:
149+
handlers.append(handler)
150+
return handlers

cassandra/metadata.py

Lines changed: 79 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from cassandra import SignatureDescriptor, ConsistencyLevel, InvalidRequest, Unauthorized
3838
import cassandra.cqltypes as types
3939
from cassandra.encoder import Encoder
40+
from cassandra.events import DriverEvent, HOST, HOST_CHANGED, HostEventPayload
4041
from cassandra.marshal import varint_unpack
4142
from cassandra.protocol import QueryMessage
4243
from cassandra.query import dict_factory, bind_params
@@ -121,14 +122,22 @@ class Metadata(object):
121122
dbaas = False
122123
""" A boolean indicating if connected to a DBaaS cluster """
123124

124-
def __init__(self):
125+
def __init__(self, event_bus=None):
125126
self.keyspaces = {}
126127
self.dbaas = False
127128
self._hosts = {}
128129
self._host_id_by_endpoint = {}
130+
self._runtime_states = {}
131+
self._event_bus = event_bus
129132
self._hosts_lock = RLock()
130133
self._tablets = Tablets({})
131134

135+
def set_event_bus(self, event_bus):
136+
self._event_bus = event_bus
137+
with self._hosts_lock:
138+
for runtime_state in self._runtime_states.values():
139+
runtime_state.set_event_bus(event_bus)
140+
132141
def export_schema_as_string(self):
133142
"""
134143
Returns a string that can be executed as a query in order to recreate
@@ -340,21 +349,30 @@ def add_or_return_host(self, host):
340349
try:
341350
return self._hosts[host.host_id], False
342351
except KeyError:
352+
host = self._bind_runtime_state(host)
343353
self._host_id_by_endpoint[host.endpoint] = host.host_id
344354
self._hosts[host.host_id] = host
345355
return host, True
346356

347357
def remove_host(self, host):
348358
self._tablets.drop_tablets_by_host_id(host.host_id)
349359
with self._hosts_lock:
360+
current_host = self._hosts.get(host.host_id)
350361
self._host_id_by_endpoint.pop(host.endpoint, False)
362+
if current_host is not None:
363+
self._host_id_by_endpoint.pop(current_host.endpoint, False)
364+
self._runtime_states.pop(host.host_id, None)
351365
return bool(self._hosts.pop(host.host_id, False))
352366

353367
def remove_host_by_host_id(self, host_id, endpoint=None):
354368
self._tablets.drop_tablets_by_host_id(host_id)
355369
with self._hosts_lock:
356-
if endpoint and self._host_id_by_endpoint[endpoint] == host_id:
370+
current_host = self._hosts.get(host_id)
371+
if endpoint and self._host_id_by_endpoint.get(endpoint) == host_id:
357372
self._host_id_by_endpoint.pop(endpoint, False)
373+
if current_host is not None:
374+
self._host_id_by_endpoint.pop(current_host.endpoint, False)
375+
self._runtime_states.pop(host_id, None)
358376
return bool(self._hosts.pop(host_id, False))
359377

360378
def update_host(self, host, old_endpoint):
@@ -363,6 +381,65 @@ def update_host(self, host, old_endpoint):
363381
self._host_id_by_endpoint.pop(old_endpoint, False)
364382
self._host_id_by_endpoint[host.endpoint] = host.host_id
365383

384+
def replace_host(self, host_id, source=None, **fields):
385+
"""
386+
Replace a Host topology snapshot for host_id and publish HOST_CHANGED.
387+
"""
388+
with self._hosts_lock:
389+
old_host = self._hosts.get(host_id)
390+
if old_host is None:
391+
return None, ()
392+
393+
changed_fields = []
394+
old_values = {}
395+
new_values = {}
396+
397+
for field, new_value in fields.items():
398+
old_value = getattr(old_host, field)
399+
if old_value != new_value:
400+
changed_fields.append(field)
401+
old_values[field] = old_value
402+
new_values[field] = new_value
403+
404+
if not changed_fields:
405+
return old_host, ()
406+
407+
copy_kwargs = dict((field, fields[field]) for field in changed_fields)
408+
new_host = old_host.copy_with(**copy_kwargs)
409+
new_host.runtime_state.set_event_bus(self._event_bus)
410+
new_host.runtime_state.bind_host(new_host)
411+
412+
self._hosts[host_id] = new_host
413+
if "endpoint" in changed_fields:
414+
if self._host_id_by_endpoint.get(old_host.endpoint) == host_id:
415+
self._host_id_by_endpoint.pop(old_host.endpoint, False)
416+
self._host_id_by_endpoint[new_host.endpoint] = host_id
417+
else:
418+
self._host_id_by_endpoint[new_host.endpoint] = host_id
419+
420+
payload = HostEventPayload(
421+
host_id=host_id,
422+
old_host=old_host,
423+
new_host=new_host,
424+
changed_fields=tuple(changed_fields),
425+
old_values=old_values,
426+
new_values=new_values)
427+
if self._event_bus:
428+
self._event_bus.publish(DriverEvent(HOST_CHANGED, HOST, payload, source or self))
429+
return new_host, tuple(changed_fields)
430+
431+
def _bind_runtime_state(self, host):
432+
runtime_state = self._runtime_states.get(host.host_id)
433+
if runtime_state is None:
434+
runtime_state = host.runtime_state
435+
self._runtime_states[host.host_id] = runtime_state
436+
elif host.runtime_state is not runtime_state:
437+
host = host.copy_with(runtime_state=runtime_state)
438+
439+
runtime_state.set_event_bus(self._event_bus)
440+
runtime_state.bind_host(host)
441+
return host
442+
366443
def get_host(self, endpoint_or_address, port=None):
367444
"""
368445
Find a host in the metadata for a specific endpoint. If a string inet address and port are passed,

0 commit comments

Comments
 (0)