Skip to content

Commit 6282e6f

Browse files
committed
(improvement) Optimize RackAwareRoundRobinPolicy by caching some host distances
Refactor `RackAwareRoundRobinPolicy` to simplify distance calculations and memory usage. Add `self._remote_hosts` to cache which hosts are at remote distance, and `self._non_local_rack_hosts` to cache the local-DC hosts that are outside the local rack. This improves performance from ~290K query plans per second to ~600K query plans per second.
- Cache in `_remote_hosts` only the first `used_hosts_per_remote_dc` hosts of each remote DC, so that `used_hosts_per_remote_dc` is handled efficiently.
- Optimize control-plane operations (`on_up`, `on_down`) to rebuild the remote cache only when necessary (when remote hosts change or the local DC changes).
Signed-off-by: Yaniv Kaul <yaniv.kaul@scylladb.com>
1 parent 1884f59 commit 6282e6f

2 files changed

Lines changed: 70 additions & 37 deletions

File tree

cassandra/policies.py

Lines changed: 60 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,8 @@ def __init__(self, local_dc, local_rack, used_hosts_per_remote_dc=0):
359359
self.used_hosts_per_remote_dc = used_hosts_per_remote_dc
360360
self._live_hosts = {}
361361
self._dc_live_hosts = {}
362+
self._remote_hosts = {}
363+
self._non_local_rack_hosts = []
362364
self._endpoints = []
363365
self._position = 0
364366
LoadBalancingPolicy.__init__(self)
@@ -369,78 +371,83 @@ def _rack(self, host):
369371
def _dc(self, host):
370372
return host.datacenter or self.local_dc
371373

374+
def _refresh_remote_hosts(self):
375+
remote_hosts = {}
376+
if self.used_hosts_per_remote_dc > 0:
377+
for datacenter, hosts in self._dc_live_hosts.items():
378+
if datacenter != self.local_dc:
379+
remote_hosts.update(dict.fromkeys(hosts[:self.used_hosts_per_remote_dc]))
380+
self._remote_hosts = remote_hosts
381+
382+
def _refresh_non_local_rack_hosts(self):
383+
local_live = self._dc_live_hosts.get(self.local_dc, ())
384+
self._non_local_rack_hosts = [h for h in local_live if self._rack(h) != self.local_rack]
385+
372386
def populate(self, cluster, hosts):
373387
for (dc, rack), rack_hosts in groupby(hosts, lambda host: (self._dc(host), self._rack(host))):
374388
self._live_hosts[(dc, rack)] = tuple({*rack_hosts, *self._live_hosts.get((dc, rack), [])})
375389
for dc, dc_hosts in groupby(hosts, lambda host: self._dc(host)):
376390
self._dc_live_hosts[dc] = tuple({*dc_hosts, *self._dc_live_hosts.get(dc, [])})
377391

378392
self._position = randint(0, len(hosts) - 1) if hosts else 0
393+
self._refresh_remote_hosts()
394+
self._refresh_non_local_rack_hosts()
379395

380396
def distance(self, host):
381-
rack = self._rack(host)
382397
dc = self._dc(host)
383-
if rack == self.local_rack and dc == self.local_dc:
384-
return HostDistance.LOCAL_RACK
385-
386398
if dc == self.local_dc:
399+
if self._rack(host) == self.local_rack:
400+
return HostDistance.LOCAL_RACK
387401
return HostDistance.LOCAL
388402

389-
if not self.used_hosts_per_remote_dc:
390-
return HostDistance.IGNORED
391-
392-
dc_hosts = self._dc_live_hosts.get(dc, ())
393-
if not dc_hosts:
394-
return HostDistance.IGNORED
395-
if host in dc_hosts and dc_hosts.index(host) < self.used_hosts_per_remote_dc:
403+
if host in self._remote_hosts:
396404
return HostDistance.REMOTE
397-
else:
398-
return HostDistance.IGNORED
405+
return HostDistance.IGNORED
399406

400407
def make_query_plan(self, working_keyspace=None, query=None):
401408
pos = self._position
402409
self._position += 1
403410

404411
local_rack_live = self._live_hosts.get((self.local_dc, self.local_rack), ())
405-
pos = (pos % len(local_rack_live)) if local_rack_live else 0
406-
# Slice the cyclic iterator to start from pos and include the next len(local_live) elements
407-
# This ensures we get exactly one full cycle starting from pos
408-
for host in islice(cycle(local_rack_live), pos, pos + len(local_rack_live)):
409-
yield host
412+
length = len(local_rack_live)
413+
if length:
414+
p = pos % length
415+
for host in islice(cycle(local_rack_live), p, p + length):
416+
yield host
410417

411-
local_live = [host for host in self._dc_live_hosts.get(self.local_dc, ()) if host.rack != self.local_rack]
412-
pos = (pos % len(local_live)) if local_live else 0
413-
for host in islice(cycle(local_live), pos, pos + len(local_live)):
414-
yield host
418+
local_non_rack = self._non_local_rack_hosts
419+
length = len(local_non_rack)
420+
if length:
421+
p = pos % length
422+
for host in islice(cycle(local_non_rack), p, p + length):
423+
yield host
415424

416-
# the dict can change, so get candidate DCs iterating over keys of a copy
417-
for dc, remote_live in self._dc_live_hosts.copy().items():
418-
if dc != self.local_dc:
419-
for host in remote_live[:self.used_hosts_per_remote_dc]:
420-
yield host
425+
for host in self._remote_hosts:
426+
yield host
421427

422428
def on_up(self, host):
423429
dc = self._dc(host)
424430
rack = self._rack(host)
425431
with self._hosts_lock:
426-
current_rack_hosts = self._live_hosts.get((dc, rack), ())
427-
if host not in current_rack_hosts:
428-
self._live_hosts[(dc, rack)] = current_rack_hosts + (host, )
429432
current_dc_hosts = self._dc_live_hosts.get(dc, ())
430433
if host not in current_dc_hosts:
431434
self._dc_live_hosts[dc] = current_dc_hosts + (host, )
432435

436+
if dc != self.local_dc:
437+
self._refresh_remote_hosts()
438+
else:
439+
self._refresh_non_local_rack_hosts()
440+
441+
current_rack_hosts = self._live_hosts.get((dc, rack), ())
442+
if host not in current_rack_hosts:
443+
self._live_hosts[(dc, rack)] = current_rack_hosts + (host, )
444+
if dc == self.local_dc:
445+
self._refresh_non_local_rack_hosts()
446+
433447
def on_down(self, host):
434448
dc = self._dc(host)
435449
rack = self._rack(host)
436450
with self._hosts_lock:
437-
current_rack_hosts = self._live_hosts.get((dc, rack), ())
438-
if host in current_rack_hosts:
439-
hosts = tuple(h for h in current_rack_hosts if h != host)
440-
if hosts:
441-
self._live_hosts[(dc, rack)] = hosts
442-
else:
443-
del self._live_hosts[(dc, rack)]
444451
current_dc_hosts = self._dc_live_hosts.get(dc, ())
445452
if host in current_dc_hosts:
446453
hosts = tuple(h for h in current_dc_hosts if h != host)
@@ -449,6 +456,22 @@ def on_down(self, host):
449456
else:
450457
del self._dc_live_hosts[dc]
451458

459+
if dc != self.local_dc:
460+
self._refresh_remote_hosts()
461+
else:
462+
self._refresh_non_local_rack_hosts()
463+
464+
current_rack_hosts = self._live_hosts.get((dc, rack), ())
465+
if host in current_rack_hosts:
466+
hosts = tuple(h for h in current_rack_hosts if h != host)
467+
if hosts:
468+
self._live_hosts[(dc, rack)] = hosts
469+
else:
470+
del self._live_hosts[(dc, rack)]
471+
472+
if dc == self.local_dc:
473+
self._refresh_non_local_rack_hosts()
474+
452475
def on_add(self, host):
453476
self.on_up(host)
454477

tests/unit/test_policies.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,16 @@ def test_get_distance(self, policy_specialization, constructor_args):
274274
assert policy.distance(host) == HostDistance.LOCAL_RACK
275275

276276
# same dc different rack
277+
# Reset policy state to simulate a fresh view or handle the "move" correctly
278+
# In a real scenario, a host moving racks would be handled by on_down/on_up or distinct host objects.
279+
# Here we are reusing the same policy instance with populate(), which merges hosts.
280+
# To avoid the host existing in both rack1 and rack2 buckets due to address equality,
281+
# we clear the internal state.
282+
if hasattr(policy, '_live_hosts'):
283+
policy._live_hosts.clear()
284+
if hasattr(policy, '_dc_live_hosts'):
285+
policy._dc_live_hosts.clear()
286+
277287
host = Host(DefaultEndPoint("ip1"), SimpleConvictionPolicy, host_id=uuid.uuid4())
278288
host.set_location_info("dc1", "rack2")
279289
policy.populate(Mock(), [host])

0 commit comments

Comments
 (0)