@@ -244,34 +244,81 @@ def __init__(self, local_dc='', used_hosts_per_remote_dc=0):
244244 self .local_dc = local_dc
245245 self .used_hosts_per_remote_dc = used_hosts_per_remote_dc
246246 self ._dc_live_hosts = {}
247+ self ._hosts_by_distance = {
248+ HostDistance .LOCAL : {},
249+ HostDistance .REMOTE : {}
250+ }
247251 self ._position = 0
248252 LoadBalancingPolicy .__init__ (self )
249253
250254 def _dc (self , host ):
251255 return host .datacenter or self .local_dc
252256
257+ def _refresh_distances (self , dc = None ):
258+ # Always called under self._hosts_lock, so it's safe to iterate over self._dc_live_hosts
259+
260+ if dc is None :
261+ # Full rebuild
262+ local_hosts = dict .fromkeys (self ._dc_live_hosts .get (self .local_dc , ()))
263+ remote_hosts = {}
264+ if self .used_hosts_per_remote_dc > 0 :
265+ for datacenter , hosts in self ._dc_live_hosts .items ():
266+ if datacenter != self .local_dc :
267+ remote_hosts .update (dict .fromkeys (hosts [:self .used_hosts_per_remote_dc ]))
268+ elif dc == self .local_dc :
269+ # Rebuild local, reuse remote
270+ local_hosts = dict .fromkeys (self ._dc_live_hosts .get (self .local_dc , ()))
271+ remote_hosts = self ._hosts_by_distance [HostDistance .REMOTE ]
272+ else :
273+ # Rebuild remote, reuse local
274+ local_hosts = self ._hosts_by_distance [HostDistance .LOCAL ]
275+ remote_hosts = {}
276+ if self .used_hosts_per_remote_dc > 0 :
277+ for datacenter , hosts in self ._dc_live_hosts .items ():
278+ if datacenter != self .local_dc :
279+ remote_hosts .update (dict .fromkeys (hosts [:self .used_hosts_per_remote_dc ]))
280+
281+ self ._hosts_by_distance = {
282+ HostDistance .LOCAL : local_hosts ,
283+ HostDistance .REMOTE : remote_hosts
284+ }
285+
286+ def _add_local_distance (self , host ):
287+ distances = self ._hosts_by_distance
288+ if host not in distances [HostDistance .LOCAL ]:
289+ new_local = distances [HostDistance .LOCAL ].copy ()
290+ new_local [host ] = None
291+ self ._hosts_by_distance = {
292+ HostDistance .LOCAL : new_local ,
293+ HostDistance .REMOTE : distances [HostDistance .REMOTE ]
294+ }
295+
296+ def _remove_local_distance (self , host ):
297+ distances = self ._hosts_by_distance
298+ if host in distances [HostDistance .LOCAL ]:
299+ new_local = distances [HostDistance .LOCAL ].copy ()
300+ del new_local [host ]
301+ self ._hosts_by_distance = {
302+ HostDistance .LOCAL : new_local ,
303+ HostDistance .REMOTE : distances [HostDistance .REMOTE ]
304+ }
305+
253306 def populate (self , cluster , hosts ):
254307 for dc , dc_hosts in groupby (hosts , lambda h : self ._dc (h )):
255308 self ._dc_live_hosts [dc ] = tuple ({* dc_hosts , * self ._dc_live_hosts .get (dc , [])})
256309
257310 self ._position = randint (0 , len (hosts ) - 1 ) if hosts else 0
311+ self ._refresh_distances ()
258312
259313 def distance (self , host ):
260- dc = self ._dc (host )
261- if dc == self .local_dc :
314+ # Capture the reference to the dict to ensure we check a consistent snapshot
315+ # (avoiding the case where we check LOCAL, fail, context switch, update happens, check REMOTE, fail)
316+ distances = self ._hosts_by_distance
317+ if host in distances [HostDistance .LOCAL ]:
262318 return HostDistance .LOCAL
263-
264- if not self .used_hosts_per_remote_dc :
265- return HostDistance .IGNORED
266- else :
267- dc_hosts = self ._dc_live_hosts .get (dc )
268- if not dc_hosts :
269- return HostDistance .IGNORED
270-
271- if host in list (dc_hosts )[:self .used_hosts_per_remote_dc ]:
272- return HostDistance .REMOTE
273- else :
274- return HostDistance .IGNORED
319+ elif host in distances [HostDistance .REMOTE ]:
320+ return HostDistance .REMOTE
321+ return HostDistance .IGNORED
275322
276323 def make_query_plan (self , working_keyspace = None , query = None ):
277324 # not thread-safe, but we don't care much about lost increments
@@ -284,29 +331,34 @@ def make_query_plan(self, working_keyspace=None, query=None):
284331 for host in islice (cycle (local_live ), pos , pos + len (local_live )):
285332 yield host
286333
287- # the dict can change, so get candidate DCs iterating over keys of a copy
288- other_dcs = [dc for dc in self ._dc_live_hosts .copy ().keys () if dc != self .local_dc ]
289- for dc in other_dcs :
290- remote_live = self ._dc_live_hosts .get (dc , ())
291- for host in remote_live [:self .used_hosts_per_remote_dc ]:
292- yield host
334+ for host in self ._hosts_by_distance [HostDistance .REMOTE ]:
335+ yield host
293336
294337 def on_up (self , host ):
295338 # not worrying about threads because this will happen during
296339 # control connection startup/refresh
340+ refresh_all = False
297341 if not self .local_dc and host .datacenter :
298342 self .local_dc = host .datacenter
299343 log .info ("Using datacenter '%s' for DCAwareRoundRobinPolicy (via host '%s'); "
300344 "if incorrect, please specify a local_dc to the constructor, "
301345 "or limit contact points to local cluster nodes" %
302346 (self .local_dc , host .endpoint ))
347+ refresh_all = True
303348
304349 dc = self ._dc (host )
305350 with self ._hosts_lock :
306351 current_hosts = self ._dc_live_hosts .get (dc , ())
307352 if host not in current_hosts :
308353 self ._dc_live_hosts [dc ] = current_hosts + (host , )
309354
355+ if refresh_all :
356+ self ._refresh_distances ()
357+ elif dc == self .local_dc :
358+ self ._add_local_distance (host )
359+ else :
360+ self ._refresh_distances (dc )
361+
310362 def on_down (self , host ):
311363 dc = self ._dc (host )
312364 with self ._hosts_lock :
@@ -318,6 +370,11 @@ def on_down(self, host):
318370 else :
319371 del self ._dc_live_hosts [dc ]
320372
373+ if dc == self .local_dc :
374+ self ._remove_local_distance (host )
375+ else :
376+ self ._refresh_distances (dc )
377+
321378 def on_add (self , host ):
322379 self .on_up (host )
323380
0 commit comments