Skip to content

Commit d580eed

Browse files
committed
perf: optimize DCAwareRoundRobinPolicy with cached remote hosts
Introduce _remote_hosts dict to cache REMOTE hosts, enabling O(1) distance lookups instead of scanning per-DC host lists. Replace islice(cycle(...)) with index arithmetic in make_query_plan. Call _refresh_remote_hosts() on topology changes.
1 parent 8e6c4d4 commit d580eed

1 file changed

Lines changed: 37 additions & 21 deletions

File tree

cassandra/policies.py

Lines changed: 37 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
import random
1515

16-
from collections import namedtuple
16+
from collections import namedtuple, OrderedDict
1717
from itertools import islice, cycle, groupby, repeat
1818
import logging
1919
from random import randint, shuffle
@@ -244,34 +244,41 @@ def __init__(self, local_dc='', used_hosts_per_remote_dc=0):
244244
self.local_dc = local_dc
245245
self.used_hosts_per_remote_dc = used_hosts_per_remote_dc
246246
self._dc_live_hosts = {}
247+
self._remote_hosts = {}
247248
self._position = 0
248249
LoadBalancingPolicy.__init__(self)
249250

250251
def _dc(self, host):
251252
return host.datacenter or self.local_dc
252253

254+
def _refresh_remote_hosts(self):
255+
# Using dict.fromkeys() instead of a set to preserve insertion order (Python 3.7+)
256+
# while still providing O(1) lookup for `host in self._remote_hosts`.
257+
remote_hosts = {}
258+
if self.used_hosts_per_remote_dc > 0:
259+
for datacenter, hosts in self._dc_live_hosts.items():
260+
if datacenter != self.local_dc:
261+
remote_hosts.update(
262+
dict.fromkeys(hosts[:self.used_hosts_per_remote_dc])
263+
)
264+
self._remote_hosts = remote_hosts
265+
253266
def populate(self, cluster, hosts):
254267
for dc, dc_hosts in groupby(hosts, lambda h: self._dc(h)):
255268
self._dc_live_hosts[dc] = tuple({*dc_hosts, *self._dc_live_hosts.get(dc, [])})
256269

257270
self._position = randint(0, len(hosts) - 1) if hosts else 0
271+
self._refresh_remote_hosts()
258272

259273
def distance(self, host):
260274
dc = self._dc(host)
261275
if dc == self.local_dc:
262276
return HostDistance.LOCAL
263277

264-
if not self.used_hosts_per_remote_dc:
265-
return HostDistance.IGNORED
266-
else:
267-
dc_hosts = self._dc_live_hosts.get(dc)
268-
if not dc_hosts:
269-
return HostDistance.IGNORED
270-
271-
if host in list(dc_hosts)[:self.used_hosts_per_remote_dc]:
272-
return HostDistance.REMOTE
273-
else:
274-
return HostDistance.IGNORED
278+
remote_hosts = self._remote_hosts
279+
if host in remote_hosts:
280+
return HostDistance.REMOTE
281+
return HostDistance.IGNORED
275282

276283
def make_query_plan(self, working_keyspace=None, query=None):
277284
# not thread-safe, but we don't care much about lost increments
@@ -280,32 +287,38 @@ def make_query_plan(self, working_keyspace=None, query=None):
280287
self._position += 1
281288

282289
local_live = self._dc_live_hosts.get(self.local_dc, ())
283-
pos = (pos % len(local_live)) if local_live else 0
284-
for host in islice(cycle(local_live), pos, pos + len(local_live)):
285-
yield host
290+
length = len(local_live)
291+
if length:
292+
pos %= length
293+
for i in range(length):
294+
yield local_live[(pos + i) % length]
286295

287-
# the dict can change, so get candidate DCs iterating over keys of a copy
288-
other_dcs = [dc for dc in self._dc_live_hosts.copy().keys() if dc != self.local_dc]
289-
for dc in other_dcs:
290-
remote_live = self._dc_live_hosts.get(dc, ())
291-
for host in remote_live[:self.used_hosts_per_remote_dc]:
292-
yield host
296+
remote_hosts = self._remote_hosts
297+
for host in remote_hosts:
298+
yield host
293299

294300
def on_up(self, host):
295301
# not worrying about threads because this will happen during
296302
# control connection startup/refresh
303+
refresh_remote = False
297304
if not self.local_dc and host.datacenter:
298305
self.local_dc = host.datacenter
299306
log.info("Using datacenter '%s' for DCAwareRoundRobinPolicy (via host '%s'); "
300307
"if incorrect, please specify a local_dc to the constructor, "
301308
"or limit contact points to local cluster nodes" %
302309
(self.local_dc, host.endpoint))
310+
refresh_remote = True
303311

304312
dc = self._dc(host)
305313
with self._hosts_lock:
306314
current_hosts = self._dc_live_hosts.get(dc, ())
307315
if host not in current_hosts:
308316
self._dc_live_hosts[dc] = current_hosts + (host, )
317+
if dc != self.local_dc:
318+
refresh_remote = True
319+
320+
if refresh_remote:
321+
self._refresh_remote_hosts()
309322

310323
def on_down(self, host):
311324
dc = self._dc(host)
@@ -318,6 +331,9 @@ def on_down(self, host):
318331
else:
319332
del self._dc_live_hosts[dc]
320333

334+
if dc != self.local_dc:
335+
self._refresh_remote_hosts()
336+
321337
def on_add(self, host):
322338
self.on_up(host)
323339

0 commit comments

Comments
 (0)