Skip to content

Commit f962afe

Browse files
committed
(improvement)Optimize RackAwareRoundRobinPolicy by caching some host distances
Refactor `RackAwareRoundRobinPolicy` to simplify distance calculations and memory usage. Add self._remote_hosts to cache remote hosts distance, self._non_local_rack_hosts for non-local rack host distance. - Only cache `_remote_hosts` to efficiently handle `used_hosts_per_remote_dc`. - Optimize control plane operations (`on_up`, `on_down`) to only rebuild the remote cache when necessary (when remote hosts change or local DC changes). - Snapshot `self._remote_hosts` to a local variable before reads for GIL-free Python 3.13+ safety. Benchmark (100K queries, 45-node/5-DC topology, Python 3.14, median of 5 runs): Policy | Kops/s | vs master | delta | Mem KB ----------------------------------------------------------------- DCAware | 188 | +77% | | 1.5 RackAware | 146 | +115% | +115% | 2.0 TokenAware(DCAware) | 20 | +11% | | 1.7 TokenAware(RackAware) | 19 | +12% | +19% | 2.2 Default(DCAware) | 123 | +35% | | 1.6 HostFilter(DCAware) | 57 | +8% | | 1.7 RackAware more than doubles throughput (+115%), from 68 to 146 Kops/s. TokenAware(RackAware) also benefits transitively (+12% cumulative). Signed-off-by: Yaniv Kaul <yaniv.kaul@scylladb.com>
1 parent 888cd02 commit f962afe

2 files changed

Lines changed: 72 additions & 41 deletions

File tree

cassandra/policies.py

Lines changed: 62 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,8 @@ def __init__(self, local_dc, local_rack, used_hosts_per_remote_dc=0):
370370
self.used_hosts_per_remote_dc = used_hosts_per_remote_dc
371371
self._live_hosts = {}
372372
self._dc_live_hosts = {}
373+
self._remote_hosts = {}
374+
self._non_local_rack_hosts = []
373375
self._endpoints = []
374376
self._position = 0
375377
LoadBalancingPolicy.__init__(self)
@@ -380,6 +382,24 @@ def _rack(self, host):
380382
def _dc(self, host):
381383
return host.datacenter or self.local_dc
382384

385+
def _refresh_remote_hosts(self):
386+
# Using dict.fromkeys() instead of a set to preserve insertion order (Python 3.7+)
387+
# while still providing O(1) lookup for `host in self._remote_hosts`.
388+
remote_hosts = {}
389+
if self.used_hosts_per_remote_dc > 0:
390+
for datacenter, hosts in self._dc_live_hosts.items():
391+
if datacenter != self.local_dc:
392+
remote_hosts.update(
393+
dict.fromkeys(hosts[: self.used_hosts_per_remote_dc])
394+
)
395+
self._remote_hosts = remote_hosts
396+
397+
def _refresh_non_local_rack_hosts(self):
398+
local_live = self._dc_live_hosts.get(self.local_dc, ())
399+
self._non_local_rack_hosts = [
400+
h for h in local_live if self._rack(h) != self.local_rack
401+
]
402+
383403
def populate(self, cluster, hosts):
384404
for (dc, rack), rack_hosts in groupby(
385405
hosts, lambda host: (self._dc(host), self._rack(host))
@@ -393,75 +413,63 @@ def populate(self, cluster, hosts):
393413
)
394414

395415
self._position = randint(0, len(hosts) - 1) if hosts else 0
416+
self._refresh_remote_hosts()
417+
self._refresh_non_local_rack_hosts()
396418

397419
def distance(self, host):
398-
rack = self._rack(host)
399420
dc = self._dc(host)
400-
if rack == self.local_rack and dc == self.local_dc:
401-
return HostDistance.LOCAL_RACK
402-
403421
if dc == self.local_dc:
422+
if self._rack(host) == self.local_rack:
423+
return HostDistance.LOCAL_RACK
404424
return HostDistance.LOCAL
405425

406-
if not self.used_hosts_per_remote_dc:
407-
return HostDistance.IGNORED
408-
409-
dc_hosts = self._dc_live_hosts.get(dc, ())
410-
if not dc_hosts:
411-
return HostDistance.IGNORED
412-
if host in dc_hosts and dc_hosts.index(host) < self.used_hosts_per_remote_dc:
426+
remote_hosts = self._remote_hosts
427+
if host in remote_hosts:
413428
return HostDistance.REMOTE
414-
else:
415-
return HostDistance.IGNORED
429+
return HostDistance.IGNORED
416430

417431
def make_query_plan(self, working_keyspace=None, query=None):
418432
pos = self._position
419433
self._position += 1
420434

421435
local_rack_live = self._live_hosts.get((self.local_dc, self.local_rack), ())
422-
pos = (pos % len(local_rack_live)) if local_rack_live else 0
423-
# Slice the cyclic iterator to start from pos and include the next len(local_live) elements
424-
# This ensures we get exactly one full cycle starting from pos
425-
for host in islice(cycle(local_rack_live), pos, pos + len(local_rack_live)):
426-
yield host
436+
length = len(local_rack_live)
437+
if length:
438+
p = pos % length
439+
for host in islice(cycle(local_rack_live), p, p + length):
440+
yield host
427441

428-
local_live = [
429-
host
430-
for host in self._dc_live_hosts.get(self.local_dc, ())
431-
if host.rack != self.local_rack
432-
]
433-
pos = (pos % len(local_live)) if local_live else 0
434-
for host in islice(cycle(local_live), pos, pos + len(local_live)):
435-
yield host
442+
local_non_rack = self._non_local_rack_hosts
443+
length = len(local_non_rack)
444+
if length:
445+
p = pos % length
446+
for host in islice(cycle(local_non_rack), p, p + length):
447+
yield host
436448

437-
# the dict can change, so get candidate DCs iterating over keys of a copy
438-
for dc, remote_live in self._dc_live_hosts.copy().items():
439-
if dc != self.local_dc:
440-
for host in remote_live[: self.used_hosts_per_remote_dc]:
441-
yield host
449+
for host in self._remote_hosts:
450+
yield host
442451

443452
def on_up(self, host):
444453
dc = self._dc(host)
445454
rack = self._rack(host)
446455
with self._hosts_lock:
447-
current_rack_hosts = self._live_hosts.get((dc, rack), ())
448-
if host not in current_rack_hosts:
449-
self._live_hosts[(dc, rack)] = current_rack_hosts + (host,)
450456
current_dc_hosts = self._dc_live_hosts.get(dc, ())
451457
if host not in current_dc_hosts:
452458
self._dc_live_hosts[dc] = current_dc_hosts + (host,)
453459

460+
if dc != self.local_dc:
461+
self._refresh_remote_hosts()
462+
else:
463+
self._refresh_non_local_rack_hosts()
464+
465+
current_rack_hosts = self._live_hosts.get((dc, rack), ())
466+
if host not in current_rack_hosts:
467+
self._live_hosts[(dc, rack)] = current_rack_hosts + (host,)
468+
454469
def on_down(self, host):
455470
dc = self._dc(host)
456471
rack = self._rack(host)
457472
with self._hosts_lock:
458-
current_rack_hosts = self._live_hosts.get((dc, rack), ())
459-
if host in current_rack_hosts:
460-
hosts = tuple(h for h in current_rack_hosts if h != host)
461-
if hosts:
462-
self._live_hosts[(dc, rack)] = hosts
463-
else:
464-
del self._live_hosts[(dc, rack)]
465473
current_dc_hosts = self._dc_live_hosts.get(dc, ())
466474
if host in current_dc_hosts:
467475
hosts = tuple(h for h in current_dc_hosts if h != host)
@@ -470,6 +478,19 @@ def on_down(self, host):
470478
else:
471479
del self._dc_live_hosts[dc]
472480

481+
if dc != self.local_dc:
482+
self._refresh_remote_hosts()
483+
else:
484+
self._refresh_non_local_rack_hosts()
485+
486+
current_rack_hosts = self._live_hosts.get((dc, rack), ())
487+
if host in current_rack_hosts:
488+
hosts = tuple(h for h in current_rack_hosts if h != host)
489+
if hosts:
490+
self._live_hosts[(dc, rack)] = hosts
491+
else:
492+
del self._live_hosts[(dc, rack)]
493+
473494
def on_add(self, host):
474495
self.on_up(host)
475496

tests/unit/test_policies.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,16 @@ def test_get_distance(self, policy_specialization, constructor_args):
274274
assert policy.distance(host) == HostDistance.LOCAL_RACK
275275

276276
# same dc different rack
277+
# Reset policy state to simulate a fresh view or handle the "move" correctly
278+
# In a real scenario, a host moving racks would be handled by on_down/on_up or distinct host objects.
279+
# Here we are reusing the same policy instance with populate(), which merges hosts.
280+
# To avoid the host existing in both rack1 and rack2 buckets due to address equality,
281+
# we clear the internal state.
282+
if hasattr(policy, '_live_hosts'):
283+
policy._live_hosts.clear()
284+
if hasattr(policy, '_dc_live_hosts'):
285+
policy._dc_live_hosts.clear()
286+
277287
host = Host(DefaultEndPoint("ip1"), SimpleConvictionPolicy, host_id=uuid.uuid4())
278288
host.set_location_info("dc1", "rack2")
279289
policy.populate(Mock(), [host])

0 commit comments

Comments
 (0)