Skip to content

Commit b0359a6

Browse files
committed
(improvement) Optimize DCAwareRoundRobinPolicy with host distance caching
Refactor `DCAwareRoundRobinPolicy` to use a Copy-On-Write (COW) strategy for managing host distances.

Key changes:
- Introduce `_hosts_by_distance` to cache `LOCAL` and `REMOTE` hosts, enabling O(1) distance lookups and thread-safe iteration during query planning. `IGNORED` hosts do not need to be stored in the cache.
- Optimize control plane operations (`on_up`, `on_down`):
  - Add `_add_local_distance` and `_remove_local_distance` to handle local DC node changes via atomic dictionary swaps, avoiding full cache rebuilds.
  - Update `_refresh_distances` to support partial refreshes, reusing existing `LOCAL` or `REMOTE` structures when only one DC changes.

Signed-off-by: Yaniv Kaul <yaniv.kaul@scylladb.com>
1 parent 711a7eb commit b0359a6

1 file changed

Lines changed: 77 additions & 20 deletions

File tree

cassandra/policies.py

Lines changed: 77 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -244,34 +244,81 @@ def __init__(self, local_dc='', used_hosts_per_remote_dc=0):
244244
self.local_dc = local_dc
245245
self.used_hosts_per_remote_dc = used_hosts_per_remote_dc
246246
self._dc_live_hosts = {}
247+
self._hosts_by_distance = {
248+
HostDistance.LOCAL: {},
249+
HostDistance.REMOTE: {}
250+
}
247251
self._position = 0
248252
LoadBalancingPolicy.__init__(self)
249253

250254
def _dc(self, host):
    """Return the datacenter of ``host``, falling back to ``self.local_dc``.

    The fallback is used whenever ``host.datacenter`` is falsy (for example
    ``None`` or an empty string).
    """
    datacenter = host.datacenter
    if datacenter:
        return datacenter
    return self.local_dc
252256

257+
def _refresh_distances(self, dc=None):
    """Rebuild the cached ``LOCAL``/``REMOTE`` host maps and swap them in.

    :param dc: datacenter whose membership changed. ``None`` rebuilds both
        maps; a value equal to ``self.local_dc`` rebuilds only the local
        map and reuses the cached remote map; any other value rebuilds only
        the remote map and reuses the cached local map.

    Both maps are dicts keyed by host with ``None`` values. The result is
    assigned to ``self._hosts_by_distance`` as a brand-new outer dict; the
    previously published dicts are never mutated here.
    """
    # Callers are expected to hold self._hosts_lock, so it is safe to
    # iterate over self._dc_live_hosts.
    cached = self._hosts_by_distance
    local_dc = self.local_dc

    if dc is None or dc == local_dc:
        local_hosts = dict.fromkeys(self._dc_live_hosts.get(local_dc, ()))
    else:
        local_hosts = cached[HostDistance.LOCAL]

    if dc is None or dc != local_dc:
        remote_hosts = {}
        limit = self.used_hosts_per_remote_dc
        if limit > 0:
            for datacenter, hosts in self._dc_live_hosts.items():
                if datacenter != local_dc:
                    # Only the first `limit` live hosts of each non-local
                    # datacenter are cached as REMOTE.
                    remote_hosts.update(dict.fromkeys(hosts[:limit]))
    else:
        remote_hosts = cached[HostDistance.REMOTE]

    self._hosts_by_distance = {
        HostDistance.LOCAL: local_hosts,
        HostDistance.REMOTE: remote_hosts,
    }
285+
286+
def _add_local_distance(self, host):
    """Add ``host`` to the cached ``LOCAL`` map if it is not already there.

    A new ``LOCAL`` dict (the current entries plus ``host``) is built and
    ``self._hosts_by_distance`` is rebound to a new outer dict that reuses
    the current ``REMOTE`` dict. The existing dicts are left unmodified.
    """
    snapshot = self._hosts_by_distance
    local_hosts = snapshot[HostDistance.LOCAL]
    if host in local_hosts:
        # Already cached as LOCAL; nothing to publish.
        return

    self._hosts_by_distance = {
        HostDistance.LOCAL: {**local_hosts, host: None},
        HostDistance.REMOTE: snapshot[HostDistance.REMOTE],
    }
295+
296+
def _remove_local_distance(self, host):
    """Drop ``host`` from the cached ``LOCAL`` map if it is present.

    A new ``LOCAL`` dict without ``host`` is built and
    ``self._hosts_by_distance`` is rebound to a new outer dict that reuses
    the current ``REMOTE`` dict. The existing dicts are left unmodified.
    """
    snapshot = self._hosts_by_distance
    local_hosts = snapshot[HostDistance.LOCAL]
    if host not in local_hosts:
        # Not cached as LOCAL; nothing to publish.
        return

    remaining = dict(local_hosts)
    remaining.pop(host)
    self._hosts_by_distance = {
        HostDistance.LOCAL: remaining,
        HostDistance.REMOTE: snapshot[HostDistance.REMOTE],
    }
305+
253306
def populate(self, cluster, hosts):
    """Record the given hosts per datacenter and build the distance cache.

    ``hosts`` is grouped with ``itertools.groupby`` keyed by ``self._dc``
    (so only consecutive hosts with the same datacenter form a group). Each
    group is merged with the hosts already recorded for that datacenter and
    stored as a tuple. ``cluster`` is not used by this method.
    """
    for datacenter, group in groupby(hosts, self._dc):
        merged = set(group)
        merged.update(self._dc_live_hosts.get(datacenter, []))
        self._dc_live_hosts[datacenter] = tuple(merged)

    # Pick a random starting offset; fall back to 0 when there are no hosts.
    if hosts:
        self._position = randint(0, len(hosts) - 1)
    else:
        self._position = 0

    self._refresh_distances()
258312

259313
def distance(self, host):
# Diff view of distance(): a line following an `NNN-` marker belongs to the
# removed implementation, a line following an `NNN+` marker to the added one,
# and a line following a two-number marker is unchanged context.
# Added implementation: return LOCAL when `host` is a key of the cached LOCAL
# map, REMOTE when it is a key of the cached REMOTE map, otherwise IGNORED.
# Removed implementation: derived the answer from self._dc(host) and the first
# `used_hosts_per_remote_dc` entries of self._dc_live_hosts for that datacenter.
260-
dc = self._dc(host)
261-
if dc == self.local_dc:
314+
# Capture the reference to the dict to ensure we check a consistent snapshot
315+
# (avoiding the case where we check LOCAL, fail, context switch, update happens, check REMOTE, fail)
316+
distances = self._hosts_by_distance
317+
if host in distances[HostDistance.LOCAL]:
262318
return HostDistance.LOCAL
263-
264-
if not self.used_hosts_per_remote_dc:
265-
return HostDistance.IGNORED
266-
else:
267-
dc_hosts = self._dc_live_hosts.get(dc)
268-
if not dc_hosts:
269-
return HostDistance.IGNORED
270-
271-
if host in list(dc_hosts)[:self.used_hosts_per_remote_dc]:
272-
return HostDistance.REMOTE
273-
else:
274-
return HostDistance.IGNORED
319+
elif host in distances[HostDistance.REMOTE]:
320+
return HostDistance.REMOTE
321+
# Hosts found in neither cached map are reported as IGNORED.
return HostDistance.IGNORED
275322

276323
def make_query_plan(self, working_keyspace=None, query=None):
277324
# not thread-safe, but we don't care much about lost increments
@@ -284,29 +331,34 @@ def make_query_plan(self, working_keyspace=None, query=None):
284331
for host in islice(cycle(local_live), pos, pos + len(local_live)):
285332
yield host
286333

287-
# the dict can change, so get candidate DCs iterating over keys of a copy
288-
other_dcs = [dc for dc in self._dc_live_hosts.copy().keys() if dc != self.local_dc]
289-
for dc in other_dcs:
290-
remote_live = self._dc_live_hosts.get(dc, ())
291-
for host in remote_live[:self.used_hosts_per_remote_dc]:
292-
yield host
334+
for host in self._hosts_by_distance[HostDistance.REMOTE]:
335+
yield host
293336

294337
def on_up(self, host):
# Record `host` as live in its datacenter's tuple in self._dc_live_hosts and
# bring the cached distance maps up to date.
295338
# not worrying about threads because this will happen during
296339
# control connection startup/refresh
340+
# Set to True below when local_dc gets assigned by this call, in which case
# the distance cache is fully rebuilt.
refresh_all = False
297341
# No local_dc configured yet: adopt the datacenter of this host.
if not self.local_dc and host.datacenter:
298342
self.local_dc = host.datacenter
299343
log.info("Using datacenter '%s' for DCAwareRoundRobinPolicy (via host '%s'); "
300344
"if incorrect, please specify a local_dc to the constructor, "
301345
"or limit contact points to local cluster nodes" %
302346
(self.local_dc, host.endpoint))
347+
refresh_all = True
303348

304349
dc = self._dc(host)
305350
with self._hosts_lock:
306351
current_hosts = self._dc_live_hosts.get(dc, ())
307352
if host not in current_hosts:
308353
# Tuples are replaced rather than mutated: a new tuple with the host appended.
self._dc_live_hosts[dc] = current_hosts + (host, )
309354

355+
# NOTE(review): indentation is lost in this view; presumably this chain runs
# inside the `with self._hosts_lock:` block, since _refresh_distances is
# documented as being called under that lock -- confirm against the file.
# Full rebuild when local_dc was just assigned, a LOCAL-only add for a host in
# the local datacenter, otherwise a refresh keyed by the host's datacenter.
if refresh_all:
356+
self._refresh_distances()
357+
elif dc == self.local_dc:
358+
self._add_local_distance(host)
359+
else:
360+
self._refresh_distances(dc)
361+
310362
def on_down(self, host):
311363
dc = self._dc(host)
312364
with self._hosts_lock:
@@ -318,6 +370,11 @@ def on_down(self, host):
318370
else:
319371
del self._dc_live_hosts[dc]
320372

373+
if dc == self.local_dc:
374+
self._remove_local_distance(host)
375+
else:
376+
self._refresh_distances(dc)
377+
321378
def on_add(self, host):
322379
self.on_up(host)
323380

0 commit comments

Comments
 (0)