Skip to content

Commit 3dca0c2

Browse files
committed
perf: add LRU replica cache and optimize TokenAwarePolicy query plan
- Add LRU cache (default 1024 entries) for token-to-replicas lookups, auto-invalidated on topology changes (token_map identity check). - Sort replicas by distance (LOCAL_RACK > LOCAL > REMOTE) in a single pass instead of iterating three times. - Skip distance re-sorting for DCAware/RackAware child policies since they already yield in distance order; fallback re-sort for others. - LWT queries skip replica shuffling for deterministic plans. - Use make_query_plan_with_exclusion to avoid re-yielding replicas.
1 parent 4f01775 commit 3dca0c2

1 file changed

Lines changed: 144 additions & 26 deletions

File tree

cassandra/policies.py

Lines changed: 144 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -616,6 +616,11 @@ class TokenAwarePolicy(LoadBalancingPolicy):
616616
617617
If no :attr:`~.Statement.routing_key` is set on the query, the child
618618
policy's query plan will be used as is.
619+
620+
An LRU cache of size :attr:`cache_replicas_size` (default 1024) avoids
621+
repeated token-to-replica lookups for the same (keyspace, routing_key)
622+
pair. Set to 0 to disable caching. The cache is automatically
623+
invalidated when the cluster topology changes.
619624
"""
620625

621626
_child_policy = None
@@ -625,9 +630,15 @@ class TokenAwarePolicy(LoadBalancingPolicy):
625630
Yield local replicas in a random order.
626631
"""
627632

628-
def __init__(self, child_policy, shuffle_replicas=True):
633+
def __init__(self, child_policy, shuffle_replicas=True, cache_replicas_size=1024):
634+
super().__init__()
629635
self._child_policy = child_policy
630636
self.shuffle_replicas = shuffle_replicas
637+
self._cluster_metadata = None
638+
self._cache_replicas_size = max(0, cache_replicas_size)
639+
self._replica_cache = OrderedDict()
640+
self._replica_cache_token_map_ref = None
641+
self._cache_lock = Lock()
631642

632643
def populate(self, cluster, hosts):
633644
self._cluster_metadata = cluster.metadata
@@ -645,40 +656,147 @@ def check_supported(self):
645656
def distance(self, *args, **kwargs):
646657
return self._child_policy.distance(*args, **kwargs)
647658

659+
def _get_cached_replicas(self, keyspace, routing_key_bytes, token_map):
660+
"""
661+
Return cached (token, replicas) for the given keyspace and routing key,
662+
or None on cache miss. The cache is invalidated whenever the token_map
663+
object identity changes (i.e. after a topology rebuild).
664+
"""
665+
if not self._cache_replicas_size:
666+
return None
667+
with self._cache_lock:
668+
if token_map is not self._replica_cache_token_map_ref:
669+
# Token map was rebuilt -- entire cache is stale.
670+
self._replica_cache = OrderedDict()
671+
self._replica_cache_token_map_ref = token_map
672+
cache_key = (keyspace, routing_key_bytes)
673+
entry = self._replica_cache.get(cache_key)
674+
if entry is not None:
675+
# Promote to most-recently-used.
676+
self._replica_cache.move_to_end(cache_key)
677+
return entry
678+
679+
def _put_cached_replicas(self, keyspace, routing_key_bytes, token, replicas, token_map):
680+
"""
681+
Store (token, replicas) in the LRU cache, evicting the oldest
682+
entry if the cache exceeds its configured size.
683+
"""
684+
if not self._cache_replicas_size:
685+
return
686+
with self._cache_lock:
687+
if token_map is not self._replica_cache_token_map_ref:
688+
self._replica_cache = OrderedDict()
689+
self._replica_cache_token_map_ref = token_map
690+
cache_key = (keyspace, routing_key_bytes)
691+
self._replica_cache[cache_key] = (token, replicas)
692+
self._replica_cache.move_to_end(cache_key)
693+
if len(self._replica_cache) > self._cache_replicas_size:
694+
self._replica_cache.popitem(last=False)
695+
648696
def make_query_plan(self, working_keyspace=None, query=None):
649697
keyspace = query.keyspace if query and query.keyspace else working_keyspace
650698

651699
child = self._child_policy
652700
if query is None or query.routing_key is None or keyspace is None:
653-
for host in child.make_query_plan(keyspace, query):
654-
yield host
701+
yield from child.make_query_plan(keyspace, query)
655702
return
656703

704+
cluster_metadata = self._cluster_metadata
705+
token_map = cluster_metadata.token_map
657706
replicas = []
658-
tablet = self._cluster_metadata._tablets.get_tablet_for_key(
659-
keyspace, query.table, self._cluster_metadata.token_map.token_class.from_key(query.routing_key))
660-
661-
if tablet is not None:
662-
replicas_mapped = set(map(lambda r: r[0], tablet.replicas))
663-
child_plan = child.make_query_plan(keyspace, query)
664-
665-
replicas = [host for host in child_plan if host.host_id in replicas_mapped]
707+
if token_map:
708+
try:
709+
token = token_map.token_class.from_key(query.routing_key)
710+
tablet = cluster_metadata._tablets.get_tablet_for_key(
711+
keyspace, query.table, token
712+
)
713+
714+
if tablet is not None:
715+
replicas_mapped = {r[0] for r in tablet.replicas}
716+
child_plan = child.make_query_plan(keyspace, query)
717+
replicas = [host for host in child_plan if host.host_id in replicas_mapped]
718+
else:
719+
cached = self._get_cached_replicas(keyspace, query.routing_key, token_map)
720+
if cached is not None:
721+
token, replicas = cached
722+
else:
723+
try:
724+
replicas = token_map.get_replicas(keyspace, token)
725+
except Exception:
726+
log.debug(
727+
"Failed to get replicas from token_map, "
728+
"falling back to cluster metadata"
729+
)
730+
replicas = cluster_metadata.get_replicas(keyspace, query.routing_key)
731+
self._put_cached_replicas(
732+
keyspace, query.routing_key, token, replicas, token_map
733+
)
734+
except Exception:
735+
log.debug(
736+
"Failed to resolve token or tablet for query plan, "
737+
"falling back to child policy",
738+
exc_info=True,
739+
)
740+
741+
if self.shuffle_replicas:
742+
if not query.is_lwt():
743+
replicas = list(replicas)
744+
shuffle(replicas)
745+
746+
local_rack = []
747+
local = []
748+
remote = []
749+
750+
child_distance = child.distance
751+
752+
for replica in replicas:
753+
if replica.is_up:
754+
d = child_distance(replica)
755+
if d == HostDistance.LOCAL_RACK:
756+
local_rack.append(replica)
757+
elif d == HostDistance.LOCAL:
758+
local.append(replica)
759+
elif d == HostDistance.REMOTE:
760+
remote.append(replica)
761+
762+
if local_rack or local or remote:
763+
yielded = set()
764+
765+
for replica in local_rack:
766+
yielded.add(replica)
767+
yield replica
768+
769+
for replica in local:
770+
yielded.add(replica)
771+
yield replica
772+
773+
for replica in remote:
774+
yielded.add(replica)
775+
yield replica
776+
777+
# Yield the rest of the cluster (non-replica hosts).
778+
# DCAware and RackAware already yield in distance order
779+
# (local_rack -> local -> remote), so we can stream directly.
780+
# For other child policies we must re-sort by distance.
781+
if isinstance(child, (DCAwareRoundRobinPolicy, RackAwareRoundRobinPolicy)):
782+
yield from child.make_query_plan_with_exclusion(keyspace, query, yielded)
783+
else:
784+
remaining_local_rack = []
785+
remaining_local = []
786+
remaining_remote = []
787+
for host in child.make_query_plan_with_exclusion(keyspace, query, yielded):
788+
d = child_distance(host)
789+
if d == HostDistance.LOCAL_RACK:
790+
remaining_local_rack.append(host)
791+
elif d == HostDistance.LOCAL:
792+
remaining_local.append(host)
793+
elif d == HostDistance.REMOTE:
794+
remaining_remote.append(host)
795+
yield from remaining_local_rack
796+
yield from remaining_local
797+
yield from remaining_remote
666798
else:
667-
replicas = self._cluster_metadata.get_replicas(keyspace, query.routing_key)
668-
669-
if self.shuffle_replicas and not query.is_lwt():
670-
shuffle(replicas)
671-
672-
def yield_in_order(hosts):
673-
for distance in [HostDistance.LOCAL_RACK, HostDistance.LOCAL, HostDistance.REMOTE]:
674-
for replica in hosts:
675-
if replica.is_up and child.distance(replica) == distance:
676-
yield replica
677-
678-
# yield replicas: local_rack, local, remote
679-
yield from yield_in_order(replicas)
680-
# yield rest of the cluster: local_rack, local, remote
681-
yield from yield_in_order([host for host in child.make_query_plan(keyspace, query) if host not in replicas])
799+
yield from child.make_query_plan(keyspace, query)
682800

683801
def on_up(self, *args, **kwargs):
684802
return self._child_policy.on_up(*args, **kwargs)

0 commit comments

Comments
 (0)