3737from cassandra import SignatureDescriptor , ConsistencyLevel , InvalidRequest , Unauthorized
3838import cassandra .cqltypes as types
3939from cassandra .encoder import Encoder
40+ from cassandra .events import DriverEvent , HOST , HOST_CHANGED , HostEventPayload
4041from cassandra .marshal import varint_unpack
4142from cassandra .protocol import QueryMessage
4243from cassandra .query import dict_factory , bind_params
@@ -121,14 +122,22 @@ class Metadata(object):
121122 dbaas = False
122123 """ A boolean indicating if connected to a DBaaS cluster """
123124
124- def __init__ (self ):
125+ def __init__ (self , event_bus = None ):
125126 self .keyspaces = {}
126127 self .dbaas = False
127128 self ._hosts = {}
128129 self ._host_id_by_endpoint = {}
130+ self ._runtime_states = {}
131+ self ._event_bus = event_bus
129132 self ._hosts_lock = RLock ()
130133 self ._tablets = Tablets ({})
131134
135+ def set_event_bus (self , event_bus ):
136+ self ._event_bus = event_bus
137+ with self ._hosts_lock :
138+ for runtime_state in self ._runtime_states .values ():
139+ runtime_state .set_event_bus (event_bus )
140+
132141 def export_schema_as_string (self ):
133142 """
134143 Returns a string that can be executed as a query in order to recreate
@@ -340,21 +349,30 @@ def add_or_return_host(self, host):
340349 try :
341350 return self ._hosts [host .host_id ], False
342351 except KeyError :
352+ host = self ._bind_runtime_state (host )
343353 self ._host_id_by_endpoint [host .endpoint ] = host .host_id
344354 self ._hosts [host .host_id ] = host
345355 return host , True
346356
347357 def remove_host (self , host ):
348358 self ._tablets .drop_tablets_by_host_id (host .host_id )
349359 with self ._hosts_lock :
360+ current_host = self ._hosts .get (host .host_id )
350361 self ._host_id_by_endpoint .pop (host .endpoint , False )
362+ if current_host is not None :
363+ self ._host_id_by_endpoint .pop (current_host .endpoint , False )
364+ self ._runtime_states .pop (host .host_id , None )
351365 return bool (self ._hosts .pop (host .host_id , False ))
352366
353367 def remove_host_by_host_id (self , host_id , endpoint = None ):
354368 self ._tablets .drop_tablets_by_host_id (host_id )
355369 with self ._hosts_lock :
356- if endpoint and self ._host_id_by_endpoint [endpoint ] == host_id :
370+ current_host = self ._hosts .get (host_id )
371+ if endpoint and self ._host_id_by_endpoint .get (endpoint ) == host_id :
357372 self ._host_id_by_endpoint .pop (endpoint , False )
373+ if current_host is not None :
374+ self ._host_id_by_endpoint .pop (current_host .endpoint , False )
375+ self ._runtime_states .pop (host_id , None )
358376 return bool (self ._hosts .pop (host_id , False ))
359377
360378 def update_host (self , host , old_endpoint ):
@@ -363,6 +381,65 @@ def update_host(self, host, old_endpoint):
363381 self ._host_id_by_endpoint .pop (old_endpoint , False )
364382 self ._host_id_by_endpoint [host .endpoint ] = host .host_id
365383
384+ def replace_host (self , host_id , source = None , ** fields ):
385+ """
386+ Replace a Host topology snapshot for host_id and publish HOST_CHANGED.
387+ """
388+ with self ._hosts_lock :
389+ old_host = self ._hosts .get (host_id )
390+ if old_host is None :
391+ return None , ()
392+
393+ changed_fields = []
394+ old_values = {}
395+ new_values = {}
396+
397+ for field , new_value in fields .items ():
398+ old_value = getattr (old_host , field )
399+ if old_value != new_value :
400+ changed_fields .append (field )
401+ old_values [field ] = old_value
402+ new_values [field ] = new_value
403+
404+ if not changed_fields :
405+ return old_host , ()
406+
407+ copy_kwargs = dict ((field , fields [field ]) for field in changed_fields )
408+ new_host = old_host .copy_with (** copy_kwargs )
409+ new_host .runtime_state .set_event_bus (self ._event_bus )
410+ new_host .runtime_state .bind_host (new_host )
411+
412+ self ._hosts [host_id ] = new_host
413+ if "endpoint" in changed_fields :
414+ if self ._host_id_by_endpoint .get (old_host .endpoint ) == host_id :
415+ self ._host_id_by_endpoint .pop (old_host .endpoint , False )
416+ self ._host_id_by_endpoint [new_host .endpoint ] = host_id
417+ else :
418+ self ._host_id_by_endpoint [new_host .endpoint ] = host_id
419+
420+ payload = HostEventPayload (
421+ host_id = host_id ,
422+ old_host = old_host ,
423+ new_host = new_host ,
424+ changed_fields = tuple (changed_fields ),
425+ old_values = old_values ,
426+ new_values = new_values )
427+ if self ._event_bus :
428+ self ._event_bus .publish (DriverEvent (HOST_CHANGED , HOST , payload , source or self ))
429+ return new_host , tuple (changed_fields )
430+
431+ def _bind_runtime_state (self , host ):
432+ runtime_state = self ._runtime_states .get (host .host_id )
433+ if runtime_state is None :
434+ runtime_state = host .runtime_state
435+ self ._runtime_states [host .host_id ] = runtime_state
436+ elif host .runtime_state is not runtime_state :
437+ host = host .copy_with (runtime_state = runtime_state )
438+
439+ runtime_state .set_event_bus (self ._event_bus )
440+ runtime_state .bind_host (host )
441+ return host
442+
366443 def get_host (self , endpoint_or_address , port = None ):
367444 """
368445 Find a host in the metadata for a specific endpoint. If a string inet address and port are passed,
0 commit comments