Skip to content

Commit 6a4878d

Browse files
committed
fix: thread-safety, is_up check, LRU size, and code deduplication
- Change _non_local_rack_hosts to tuple for thread-safe snapshot reads, matching codebase convention (tuples/frozensets for host collections).
- Restore is_up check for non-replica hosts in TokenAwarePolicy to match master's yield_in_order behavior.
- Reduce LRU cache default from 1024 to 256 entries.
- Deduplicate make_query_plan by delegating to make_query_plan_with_exclusion in DCAware and RackAware policies. Preserve lazy remote_hosts reads so concurrent on_up/on_down changes remain visible.
1 parent 516070c commit 6a4878d

1 file changed

Lines changed: 20 additions & 47 deletions

File tree

cassandra/policies.py

Lines changed: 20 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -307,21 +307,7 @@ def distance(self, host):
307307
return HostDistance.IGNORED
308308

309309
def make_query_plan(self, working_keyspace=None, query=None):
310-
# not thread-safe, but we don't care much about lost increments
311-
# for the purposes of load balancing
312-
pos = self._position
313-
self._position += 1
314-
315-
local_live = self._dc_live_hosts.get(self.local_dc, ())
316-
length = len(local_live)
317-
if length:
318-
pos %= length
319-
for i in range(length):
320-
yield local_live[(pos + i) % length]
321-
322-
remote_hosts = self._remote_hosts
323-
for host in remote_hosts:
324-
yield host
310+
return self.make_query_plan_with_exclusion(working_keyspace, query)
325311

326312
def make_query_plan_with_exclusion(self, working_keyspace=None, query=None, excluded=()):
327313
# not thread-safe, but we don't care much about lost increments
@@ -331,13 +317,14 @@ def make_query_plan_with_exclusion(self, working_keyspace=None, query=None, excl
331317

332318
local_live = self._dc_live_hosts.get(self.local_dc, ())
333319
length = len(local_live)
334-
remote_hosts = self._remote_hosts
335320
if not excluded:
336321
if length:
337322
pos %= length
338323
for i in range(length):
339324
yield local_live[(pos + i) % length]
340-
for host in remote_hosts:
325+
# Read remote_hosts lazily after yielding all local hosts,
326+
# so concurrent on_up/on_down changes are visible.
327+
for host in self._remote_hosts:
341328
yield host
342329
return
343330

@@ -352,7 +339,7 @@ def make_query_plan_with_exclusion(self, working_keyspace=None, query=None, excl
352339
continue
353340
yield host
354341

355-
for host in remote_hosts:
342+
for host in self._remote_hosts:
356343
if host in excluded:
357344
continue
358345
yield host
@@ -430,7 +417,7 @@ def __init__(self, local_dc, local_rack, used_hosts_per_remote_dc=0):
430417
self._live_hosts = {}
431418
self._dc_live_hosts = {}
432419
self._remote_hosts = {}
433-
self._non_local_rack_hosts = []
420+
self._non_local_rack_hosts = ()
434421
self._endpoints = []
435422
self._position = 0
436423
LoadBalancingPolicy.__init__(self)
@@ -455,9 +442,9 @@ def _refresh_remote_hosts(self):
455442

456443
def _refresh_non_local_rack_hosts(self):
457444
local_live = self._dc_live_hosts.get(self.local_dc, ())
458-
self._non_local_rack_hosts = [
445+
self._non_local_rack_hosts = tuple(
459446
h for h in local_live if self._rack(h) != self.local_rack
460-
]
447+
)
461448

462449
def populate(self, cluster, hosts):
463450
for (dc, rack), rack_hosts in groupby(hosts, lambda host: (self._dc(host), self._rack(host))):
@@ -482,34 +469,14 @@ def distance(self, host):
482469
return HostDistance.IGNORED
483470

484471
def make_query_plan(self, working_keyspace=None, query=None):
485-
pos = self._position
486-
self._position += 1
487-
488-
local_rack_live = self._live_hosts.get((self.local_dc, self.local_rack), ())
489-
length = len(local_rack_live)
490-
if length:
491-
p = pos % length
492-
for i in range(length):
493-
yield local_rack_live[(p + i) % length]
494-
495-
local_non_rack = self._non_local_rack_hosts
496-
length = len(local_non_rack)
497-
if length:
498-
p = pos % length
499-
for i in range(length):
500-
yield local_non_rack[(p + i) % length]
501-
502-
remote_hosts = self._remote_hosts
503-
for host in remote_hosts:
504-
yield host
472+
return self.make_query_plan_with_exclusion(working_keyspace, query)
505473

506474
def make_query_plan_with_exclusion(self, working_keyspace=None, query=None, excluded=()):
507475
pos = self._position
508476
self._position += 1
509477

510478
local_rack_live = self._live_hosts.get((self.local_dc, self.local_rack), ())
511479
length = len(local_rack_live)
512-
remote_hosts = self._remote_hosts
513480
if not excluded:
514481
if length:
515482
p = pos % length
@@ -523,7 +490,9 @@ def make_query_plan_with_exclusion(self, working_keyspace=None, query=None, excl
523490
for i in range(length):
524491
yield local_non_rack[(p + i) % length]
525492

526-
for host in remote_hosts:
493+
# Read remote_hosts lazily after yielding all local hosts,
494+
# so concurrent on_up/on_down changes are visible.
495+
for host in self._remote_hosts:
527496
yield host
528497
return
529498

@@ -548,7 +517,7 @@ def make_query_plan_with_exclusion(self, working_keyspace=None, query=None, excl
548517
continue
549518
yield host
550519

551-
for host in remote_hosts:
520+
for host in self._remote_hosts:
552521
if host in excluded:
553522
continue
554523
yield host
@@ -617,7 +586,7 @@ class TokenAwarePolicy(LoadBalancingPolicy):
617586
If no :attr:`~.Statement.routing_key` is set on the query, the child
618587
policy's query plan will be used as is.
619588
620-
An LRU cache of size :attr:`cache_replicas_size` (default 1024) avoids
589+
An LRU cache of size :attr:`cache_replicas_size` (default 256) avoids
621590
repeated token-to-replica lookups for the same (keyspace, routing_key)
622591
pair. Set to 0 to disable caching. The cache is automatically
623592
invalidated when the cluster topology changes.
@@ -630,7 +599,7 @@ class TokenAwarePolicy(LoadBalancingPolicy):
630599
Yield local replicas in a random order.
631600
"""
632601

633-
def __init__(self, child_policy, shuffle_replicas=True, cache_replicas_size=1024):
602+
def __init__(self, child_policy, shuffle_replicas=True, cache_replicas_size=256):
634603
super().__init__()
635604
self._child_policy = child_policy
636605
self.shuffle_replicas = shuffle_replicas
@@ -779,12 +748,16 @@ def make_query_plan(self, working_keyspace=None, query=None):
779748
# (local_rack -> local -> remote), so we can stream directly.
780749
# For other child policies we must re-sort by distance.
781750
if isinstance(child, (DCAwareRoundRobinPolicy, RackAwareRoundRobinPolicy)):
782-
yield from child.make_query_plan_with_exclusion(keyspace, query, yielded)
751+
for host in child.make_query_plan_with_exclusion(keyspace, query, yielded):
752+
if host.is_up:
753+
yield host
783754
else:
784755
remaining_local_rack = []
785756
remaining_local = []
786757
remaining_remote = []
787758
for host in child.make_query_plan_with_exclusion(keyspace, query, yielded):
759+
if not host.is_up:
760+
continue
788761
d = child_distance(host)
789762
if d == HostDistance.LOCAL_RACK:
790763
remaining_local_rack.append(host)

0 commit comments

Comments
 (0)