Skip to content

Commit 9f34e9e

Browse files
committed
perf: optimize RackAwareRoundRobinPolicy with cached host distances
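Cache the eligible remote hosts so distance() can use an O(1) membership check, and cache the local-DC hosts outside the local rack so make_query_plan() no longer rebuilds that list on every call. Replace islice(cycle(...)) with index arithmetic. Reorder on_up/on_down to update the DC-level host map before the rack-level map: both caches are rebuilt from the DC-level map, so it must be current when they are refreshed.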
1 parent bfdcfca commit 9f34e9e

1 file changed

Lines changed: 63 additions & 37 deletions

File tree

cassandra/policies.py

Lines changed: 63 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -429,6 +429,8 @@ def __init__(self, local_dc, local_rack, used_hosts_per_remote_dc=0):
429429
self.used_hosts_per_remote_dc = used_hosts_per_remote_dc
430430
self._live_hosts = {}
431431
self._dc_live_hosts = {}
432+
self._remote_hosts = {}
433+
self._non_local_rack_hosts = []
432434
self._endpoints = []
433435
self._position = 0
434436
LoadBalancingPolicy.__init__(self)
@@ -439,78 +441,89 @@ def _rack(self, host):
439441
def _dc(self, host):
440442
return host.datacenter or self.local_dc
441443

444+
def _refresh_remote_hosts(self):
445+
# Using dict.fromkeys() instead of a set to preserve insertion order (Python 3.7+)
446+
# while still providing O(1) lookup for `host in self._remote_hosts`.
447+
remote_hosts = {}
448+
if self.used_hosts_per_remote_dc > 0:
449+
for datacenter, hosts in self._dc_live_hosts.items():
450+
if datacenter != self.local_dc:
451+
remote_hosts.update(
452+
dict.fromkeys(hosts[:self.used_hosts_per_remote_dc])
453+
)
454+
self._remote_hosts = remote_hosts
455+
456+
def _refresh_non_local_rack_hosts(self):
457+
local_live = self._dc_live_hosts.get(self.local_dc, ())
458+
self._non_local_rack_hosts = [
459+
h for h in local_live if self._rack(h) != self.local_rack
460+
]
461+
442462
def populate(self, cluster, hosts):
443463
for (dc, rack), rack_hosts in groupby(hosts, lambda host: (self._dc(host), self._rack(host))):
444464
self._live_hosts[(dc, rack)] = tuple({*rack_hosts, *self._live_hosts.get((dc, rack), [])})
445465
for dc, dc_hosts in groupby(hosts, lambda host: self._dc(host)):
446466
self._dc_live_hosts[dc] = tuple({*dc_hosts, *self._dc_live_hosts.get(dc, [])})
447467

448468
self._position = randint(0, len(hosts) - 1) if hosts else 0
469+
self._refresh_remote_hosts()
470+
self._refresh_non_local_rack_hosts()
449471

450472
def distance(self, host):
451-
rack = self._rack(host)
452473
dc = self._dc(host)
453-
if rack == self.local_rack and dc == self.local_dc:
454-
return HostDistance.LOCAL_RACK
455-
456474
if dc == self.local_dc:
475+
if self._rack(host) == self.local_rack:
476+
return HostDistance.LOCAL_RACK
457477
return HostDistance.LOCAL
458478

459-
if not self.used_hosts_per_remote_dc:
460-
return HostDistance.IGNORED
461-
462-
dc_hosts = self._dc_live_hosts.get(dc, ())
463-
if not dc_hosts:
464-
return HostDistance.IGNORED
465-
if host in dc_hosts and dc_hosts.index(host) < self.used_hosts_per_remote_dc:
479+
remote_hosts = self._remote_hosts
480+
if host in remote_hosts:
466481
return HostDistance.REMOTE
467-
else:
468-
return HostDistance.IGNORED
482+
return HostDistance.IGNORED
469483

470484
def make_query_plan(self, working_keyspace=None, query=None):
471485
pos = self._position
472486
self._position += 1
473487

474488
local_rack_live = self._live_hosts.get((self.local_dc, self.local_rack), ())
475-
pos = (pos % len(local_rack_live)) if local_rack_live else 0
476-
# Slice the cyclic iterator to start from pos and include the next len(local_live) elements
477-
# This ensures we get exactly one full cycle starting from pos
478-
for host in islice(cycle(local_rack_live), pos, pos + len(local_rack_live)):
479-
yield host
489+
length = len(local_rack_live)
490+
if length:
491+
p = pos % length
492+
for i in range(length):
493+
yield local_rack_live[(p + i) % length]
480494

481-
local_live = [host for host in self._dc_live_hosts.get(self.local_dc, ()) if host.rack != self.local_rack]
482-
pos = (pos % len(local_live)) if local_live else 0
483-
for host in islice(cycle(local_live), pos, pos + len(local_live)):
484-
yield host
495+
local_non_rack = self._non_local_rack_hosts
496+
length = len(local_non_rack)
497+
if length:
498+
p = pos % length
499+
for i in range(length):
500+
yield local_non_rack[(p + i) % length]
485501

486-
# the dict can change, so get candidate DCs iterating over keys of a copy
487-
for dc, remote_live in self._dc_live_hosts.copy().items():
488-
if dc != self.local_dc:
489-
for host in remote_live[:self.used_hosts_per_remote_dc]:
490-
yield host
502+
remote_hosts = self._remote_hosts
503+
for host in remote_hosts:
504+
yield host
491505

492506
def on_up(self, host):
493507
dc = self._dc(host)
494508
rack = self._rack(host)
495509
with self._hosts_lock:
496-
current_rack_hosts = self._live_hosts.get((dc, rack), ())
497-
if host not in current_rack_hosts:
498-
self._live_hosts[(dc, rack)] = current_rack_hosts + (host, )
499510
current_dc_hosts = self._dc_live_hosts.get(dc, ())
500511
if host not in current_dc_hosts:
501512
self._dc_live_hosts[dc] = current_dc_hosts + (host, )
502513

514+
if dc != self.local_dc:
515+
self._refresh_remote_hosts()
516+
else:
517+
self._refresh_non_local_rack_hosts()
518+
519+
current_rack_hosts = self._live_hosts.get((dc, rack), ())
520+
if host not in current_rack_hosts:
521+
self._live_hosts[(dc, rack)] = current_rack_hosts + (host, )
522+
503523
def on_down(self, host):
504524
dc = self._dc(host)
505525
rack = self._rack(host)
506526
with self._hosts_lock:
507-
current_rack_hosts = self._live_hosts.get((dc, rack), ())
508-
if host in current_rack_hosts:
509-
hosts = tuple(h for h in current_rack_hosts if h != host)
510-
if hosts:
511-
self._live_hosts[(dc, rack)] = hosts
512-
else:
513-
del self._live_hosts[(dc, rack)]
514527
current_dc_hosts = self._dc_live_hosts.get(dc, ())
515528
if host in current_dc_hosts:
516529
hosts = tuple(h for h in current_dc_hosts if h != host)
@@ -519,6 +532,19 @@ def on_down(self, host):
519532
else:
520533
del self._dc_live_hosts[dc]
521534

535+
if dc != self.local_dc:
536+
self._refresh_remote_hosts()
537+
else:
538+
self._refresh_non_local_rack_hosts()
539+
540+
current_rack_hosts = self._live_hosts.get((dc, rack), ())
541+
if host in current_rack_hosts:
542+
hosts = tuple(h for h in current_rack_hosts if h != host)
543+
if hosts:
544+
self._live_hosts[(dc, rack)] = hosts
545+
else:
546+
del self._live_hosts[(dc, rack)]
547+
522548
def on_add(self, host):
523549
self.on_up(host)
524550

0 commit comments

Comments
 (0)