1313# limitations under the License.
1414import random
1515
16- from collections import namedtuple
16+ from collections import namedtuple , OrderedDict
1717from itertools import islice , cycle , groupby , repeat
1818import logging
1919from random import randint , shuffle
@@ -244,34 +244,41 @@ def __init__(self, local_dc='', used_hosts_per_remote_dc=0):
244244 self .local_dc = local_dc
245245 self .used_hosts_per_remote_dc = used_hosts_per_remote_dc
246246 self ._dc_live_hosts = {}
247+ self ._remote_hosts = {}
247248 self ._position = 0
248249 LoadBalancingPolicy .__init__ (self )
249250
def _dc(self, host):
    """Return the datacenter name that ``host`` is filed under.

    Uses ``host.datacenter`` when it is set (truthy); otherwise falls back
    to this policy's ``local_dc``.
    """
    datacenter = host.datacenter
    if datacenter:
        return datacenter
    return self.local_dc
252253
def _refresh_remote_hosts(self):
    """Rebuild ``self._remote_hosts`` from ``self._dc_live_hosts``.

    For every datacenter other than ``self.local_dc``, the first
    ``self.used_hosts_per_remote_dc`` live hosts are selected.  The result
    is stored as a dict used as an ordered set: dict keys keep insertion
    order (Python 3.7+) while ``host in self._remote_hosts`` stays O(1).
    When ``used_hosts_per_remote_dc`` is not positive the result is empty.
    """
    limit = self.used_hosts_per_remote_dc
    remote_hosts = {}
    if limit > 0:
        # Iterate over a snapshot: the live-host dict can be modified by
        # on_up/on_down while this rebuild is running, and iterating the
        # live dict would then raise RuntimeError.
        for datacenter, hosts in self._dc_live_hosts.copy().items():
            if datacenter != self.local_dc:
                remote_hosts.update(dict.fromkeys(hosts[:limit]))
    # Publish with a single assignment so readers never see a half-built dict.
    self._remote_hosts = remote_hosts
265+
def populate(self, cluster, hosts):
    """Seed the per-datacenter live-host table from an initial host list.

    ``hosts`` are grouped by ``self._dc(host)`` and merged into
    ``self._dc_live_hosts`` (one tuple of hosts per datacenter) without
    duplicates.  The round-robin start position is then randomised and the
    remote-host cache is rebuilt.  ``cluster`` is not used here.
    """
    for dc, dc_hosts in groupby(hosts, self._dc):
        # groupby only merges *adjacent* hosts, so the same datacenter may
        # show up several times; always merge with what is already stored.
        known = self._dc_live_hosts.get(dc, ())
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # "first N hosts" chosen for remote datacenters are deterministic
        # (a set would yield an arbitrary order).
        self._dc_live_hosts[dc] = tuple(dict.fromkeys((*known, *dc_hosts)))

    # Random starting offset; guard the empty case because randint(0, -1)
    # would raise.
    self._position = randint(0, len(hosts) - 1) if hosts else 0
    self._refresh_remote_hosts()
258272
def distance(self, host):
    """Classify ``host`` relative to this policy.

    Returns ``HostDistance.LOCAL`` when the host's datacenter (as resolved
    by ``self._dc``) equals ``self.local_dc``, ``HostDistance.REMOTE`` when
    the host is present in the ``self._remote_hosts`` cache, and
    ``HostDistance.IGNORED`` otherwise.
    """
    if self._dc(host) == self.local_dc:
        return HostDistance.LOCAL

    # Membership test against the cached remote hosts.
    is_remote = host in self._remote_hosts
    return HostDistance.REMOTE if is_remote else HostDistance.IGNORED
275282
276283 def make_query_plan (self , working_keyspace = None , query = None ):
277284 # not thread-safe, but we don't care much about lost increments
@@ -280,32 +287,38 @@ def make_query_plan(self, working_keyspace=None, query=None):
280287 self ._position += 1
281288
282289 local_live = self ._dc_live_hosts .get (self .local_dc , ())
283- pos = (pos % len (local_live )) if local_live else 0
284- for host in islice (cycle (local_live ), pos , pos + len (local_live )):
285- yield host
290+ length = len (local_live )
291+ if length :
292+ pos %= length
293+ for i in range (length ):
294+ yield local_live [(pos + i ) % length ]
286295
287- # the dict can change, so get candidate DCs iterating over keys of a copy
288- other_dcs = [dc for dc in self ._dc_live_hosts .copy ().keys () if dc != self .local_dc ]
289- for dc in other_dcs :
290- remote_live = self ._dc_live_hosts .get (dc , ())
291- for host in remote_live [:self .used_hosts_per_remote_dc ]:
292- yield host
296+ remote_hosts = self ._remote_hosts
297+ for host in remote_hosts :
298+ yield host
293299
def on_up(self, host):
    """Record ``host`` as live and refresh the remote-host cache if needed.

    If no ``local_dc`` has been set yet and the host reports a datacenter,
    that datacenter is adopted as ``local_dc`` (and logged).  The host is
    then appended to its datacenter's tuple in ``self._dc_live_hosts``
    unless it is already present.  ``self._refresh_remote_hosts()`` runs
    when ``local_dc`` was just adopted or when a host was added to a
    non-local datacenter.
    """
    # not worrying about threads because this will happen during
    # control connection startup/refresh
    refresh_remote = False
    if not self.local_dc and host.datacenter:
        self.local_dc = host.datacenter
        # Lazy %-style arguments: the message is only formatted if the
        # record is actually emitted.
        log.info("Using datacenter '%s' for DCAwareRoundRobinPolicy (via host '%s'); "
                 "if incorrect, please specify a local_dc to the constructor, "
                 "or limit contact points to local cluster nodes",
                 self.local_dc, host.endpoint)
        # Which datacenters count as remote has just changed.
        refresh_remote = True

    dc = self._dc(host)
    with self._hosts_lock:
        current_hosts = self._dc_live_hosts.get(dc, ())
        if host not in current_hosts:
            # Build a new tuple instead of mutating in place.
            self._dc_live_hosts[dc] = current_hosts + (host,)
            if dc != self.local_dc:
                refresh_remote = True

    if refresh_remote:
        self._refresh_remote_hosts()
309322
310323 def on_down (self , host ):
311324 dc = self ._dc (host )
@@ -318,6 +331,9 @@ def on_down(self, host):
318331 else :
319332 del self ._dc_live_hosts [dc ]
320333
334+ if dc != self .local_dc :
335+ self ._refresh_remote_hosts ()
336+
def on_add(self, host):
    """Handle a newly added host by delegating to ``on_up``."""
    self.on_up(host)
323339
0 commit comments