@@ -429,6 +429,8 @@ def __init__(self, local_dc, local_rack, used_hosts_per_remote_dc=0):
429429 self .used_hosts_per_remote_dc = used_hosts_per_remote_dc
430430 self ._live_hosts = {}
431431 self ._dc_live_hosts = {}
432+ self ._remote_hosts = {}
433+ self ._non_local_rack_hosts = []
432434 self ._endpoints = []
433435 self ._position = 0
434436 LoadBalancingPolicy .__init__ (self )
@@ -439,78 +441,89 @@ def _rack(self, host):
439441 def _dc (self , host ):
440442 return host .datacenter or self .local_dc
441443
444+ def _refresh_remote_hosts (self ):
445+ # Using dict.fromkeys() instead of a set to preserve insertion order (Python 3.7+)
446+ # while still providing O(1) lookup for `host in self._remote_hosts`.
447+ remote_hosts = {}
448+ if self .used_hosts_per_remote_dc > 0 :
449+ for datacenter , hosts in self ._dc_live_hosts .items ():
450+ if datacenter != self .local_dc :
451+ remote_hosts .update (
452+ dict .fromkeys (hosts [:self .used_hosts_per_remote_dc ])
453+ )
454+ self ._remote_hosts = remote_hosts
455+
456+ def _refresh_non_local_rack_hosts (self ):
457+ local_live = self ._dc_live_hosts .get (self .local_dc , ())
458+ self ._non_local_rack_hosts = [
459+ h for h in local_live if self ._rack (h ) != self .local_rack
460+ ]
461+
442462 def populate (self , cluster , hosts ):
443463 for (dc , rack ), rack_hosts in groupby (hosts , lambda host : (self ._dc (host ), self ._rack (host ))):
444464 self ._live_hosts [(dc , rack )] = tuple ({* rack_hosts , * self ._live_hosts .get ((dc , rack ), [])})
445465 for dc , dc_hosts in groupby (hosts , lambda host : self ._dc (host )):
446466 self ._dc_live_hosts [dc ] = tuple ({* dc_hosts , * self ._dc_live_hosts .get (dc , [])})
447467
448468 self ._position = randint (0 , len (hosts ) - 1 ) if hosts else 0
469+ self ._refresh_remote_hosts ()
470+ self ._refresh_non_local_rack_hosts ()
449471
450472 def distance (self , host ):
451- rack = self ._rack (host )
452473 dc = self ._dc (host )
453- if rack == self .local_rack and dc == self .local_dc :
454- return HostDistance .LOCAL_RACK
455-
456474 if dc == self .local_dc :
475+ if self ._rack (host ) == self .local_rack :
476+ return HostDistance .LOCAL_RACK
457477 return HostDistance .LOCAL
458478
459- if not self .used_hosts_per_remote_dc :
460- return HostDistance .IGNORED
461-
462- dc_hosts = self ._dc_live_hosts .get (dc , ())
463- if not dc_hosts :
464- return HostDistance .IGNORED
465- if host in dc_hosts and dc_hosts .index (host ) < self .used_hosts_per_remote_dc :
479+ remote_hosts = self ._remote_hosts
480+ if host in remote_hosts :
466481 return HostDistance .REMOTE
467- else :
468- return HostDistance .IGNORED
482+ return HostDistance .IGNORED
469483
470484 def make_query_plan (self , working_keyspace = None , query = None ):
471485 pos = self ._position
472486 self ._position += 1
473487
474488 local_rack_live = self ._live_hosts .get ((self .local_dc , self .local_rack ), ())
475- pos = ( pos % len (local_rack_live )) if local_rack_live else 0
476- # Slice the cyclic iterator to start from pos and include the next len(local_live) elements
477- # This ensures we get exactly one full cycle starting from pos
478- for host in islice ( cycle ( local_rack_live ), pos , pos + len ( local_rack_live ) ):
479- yield host
489+ length = len (local_rack_live )
490+ if length :
491+ p = pos % length
492+ for i in range ( length ):
493+ yield local_rack_live [( p + i ) % length ]
480494
481- local_live = [host for host in self ._dc_live_hosts .get (self .local_dc , ()) if host .rack != self .local_rack ]
482- pos = (pos % len (local_live )) if local_live else 0
483- for host in islice (cycle (local_live ), pos , pos + len (local_live )):
484- yield host
495+ local_non_rack = self ._non_local_rack_hosts
496+ length = len (local_non_rack )
497+ if length :
498+ p = pos % length
499+ for i in range (length ):
500+ yield local_non_rack [(p + i ) % length ]
485501
486- # the dict can change, so get candidate DCs iterating over keys of a copy
487- for dc , remote_live in self ._dc_live_hosts .copy ().items ():
488- if dc != self .local_dc :
489- for host in remote_live [:self .used_hosts_per_remote_dc ]:
490- yield host
502+ remote_hosts = self ._remote_hosts
503+ for host in remote_hosts :
504+ yield host
491505
492506 def on_up (self , host ):
493507 dc = self ._dc (host )
494508 rack = self ._rack (host )
495509 with self ._hosts_lock :
496- current_rack_hosts = self ._live_hosts .get ((dc , rack ), ())
497- if host not in current_rack_hosts :
498- self ._live_hosts [(dc , rack )] = current_rack_hosts + (host , )
499510 current_dc_hosts = self ._dc_live_hosts .get (dc , ())
500511 if host not in current_dc_hosts :
501512 self ._dc_live_hosts [dc ] = current_dc_hosts + (host , )
502513
514+ if dc != self .local_dc :
515+ self ._refresh_remote_hosts ()
516+ else :
517+ self ._refresh_non_local_rack_hosts ()
518+
519+ current_rack_hosts = self ._live_hosts .get ((dc , rack ), ())
520+ if host not in current_rack_hosts :
521+ self ._live_hosts [(dc , rack )] = current_rack_hosts + (host , )
522+
503523 def on_down (self , host ):
504524 dc = self ._dc (host )
505525 rack = self ._rack (host )
506526 with self ._hosts_lock :
507- current_rack_hosts = self ._live_hosts .get ((dc , rack ), ())
508- if host in current_rack_hosts :
509- hosts = tuple (h for h in current_rack_hosts if h != host )
510- if hosts :
511- self ._live_hosts [(dc , rack )] = hosts
512- else :
513- del self ._live_hosts [(dc , rack )]
514527 current_dc_hosts = self ._dc_live_hosts .get (dc , ())
515528 if host in current_dc_hosts :
516529 hosts = tuple (h for h in current_dc_hosts if h != host )
@@ -519,6 +532,19 @@ def on_down(self, host):
519532 else :
520533 del self ._dc_live_hosts [dc ]
521534
535+ if dc != self .local_dc :
536+ self ._refresh_remote_hosts ()
537+ else :
538+ self ._refresh_non_local_rack_hosts ()
539+
540+ current_rack_hosts = self ._live_hosts .get ((dc , rack ), ())
541+ if host in current_rack_hosts :
542+ hosts = tuple (h for h in current_rack_hosts if h != host )
543+ if hosts :
544+ self ._live_hosts [(dc , rack )] = hosts
545+ else :
546+ del self ._live_hosts [(dc , rack )]
547+
522548 def on_add (self , host ):
523549 self .on_up (host )
524550
0 commit comments