@@ -616,6 +616,11 @@ class TokenAwarePolicy(LoadBalancingPolicy):
616616
617617 If no :attr:`~.Statement.routing_key` is set on the query, the child
618618 policy's query plan will be used as is.
619+
620+ An LRU cache of size :attr:`cache_replicas_size` (default 1024) avoids
621+ repeated token-to-replica lookups for the same (keyspace, routing_key)
622+ pair. Set to 0 to disable caching. The cache is automatically
623+ invalidated when the cluster topology changes.
619624 """
620625
621626 _child_policy = None
@@ -625,9 +630,15 @@ class TokenAwarePolicy(LoadBalancingPolicy):
625630 Yield local replicas in a random order.
626631 """
627632
def __init__(self, child_policy, shuffle_replicas=True, cache_replicas_size=1024):
    """
    Wrap ``child_policy`` with token-aware routing.

    :param child_policy: policy that supplies distances and the base query plan.
    :param shuffle_replicas: stored as :attr:`shuffle_replicas`.
    :param cache_replicas_size: maximum number of entries kept in the
        replica cache; values at or below zero are stored as ``0``.
    """
    super().__init__()

    # Routing configuration.
    self._child_policy = child_policy
    self.shuffle_replicas = shuffle_replicas
    self._cluster_metadata = None

    # Replica cache state: the lock guards both the mapping and the
    # token-map reference the cached entries were computed against.
    self._cache_lock = Lock()
    self._replica_cache = OrderedDict()
    self._replica_cache_token_map_ref = None
    self._cache_replicas_size = cache_replicas_size if cache_replicas_size > 0 else 0
631642
632643 def populate (self , cluster , hosts ):
633644 self ._cluster_metadata = cluster .metadata
@@ -645,40 +656,147 @@ def check_supported(self):
def distance(self, *args, **kwargs):
    """Forward the call unchanged to the child policy's ``distance``."""
    child = self._child_policy
    return child.distance(*args, **kwargs)
647658
def _get_cached_replicas(self, keyspace, routing_key_bytes, token_map):
    """
    Look up the cached ``(token, replicas)`` entry for a routing key.

    Returns the stored entry, or ``None`` when caching is disabled
    (``_cache_replicas_size`` is falsy) or nothing is stored for
    ``(keyspace, routing_key_bytes)``.

    The cache is bound to the identity of ``token_map``: when a different
    object is passed in, every entry is dropped before the lookup.  A hit is
    promoted to most-recently-used.

    NOTE(review): only object identity is compared, so a change made in
    place to the same ``token_map`` object would not clear the cache --
    confirm the token map is always replaced rather than mutated.
    """
    if self._cache_replicas_size:
        lookup_key = (keyspace, routing_key_bytes)
        with self._cache_lock:
            if self._replica_cache_token_map_ref is not token_map:
                # Entries were computed against another token map: start over.
                self._replica_cache = OrderedDict()
                self._replica_cache_token_map_ref = token_map
            cache = self._replica_cache
            hit = cache.get(lookup_key)
            if hit is not None:
                # Keep recently used keys away from the eviction end.
                cache.move_to_end(lookup_key)
            return hit
    return None
678+
def _put_cached_replicas(self, keyspace, routing_key_bytes, token, replicas, token_map):
    """
    Remember ``(token, replicas)`` for ``(keyspace, routing_key_bytes)``.

    Does nothing when caching is disabled (``_cache_replicas_size`` is
    falsy).  If ``token_map`` is not the object the cache was last bound to,
    the cache is emptied and rebound first.  The stored key becomes the
    most-recently-used one, and the least-recently-used entry is dropped
    once the cache holds more than ``_cache_replicas_size`` entries.
    """
    if self._cache_replicas_size:
        store_key = (keyspace, routing_key_bytes)
        with self._cache_lock:
            if self._replica_cache_token_map_ref is not token_map:
                self._replica_cache = OrderedDict()
                self._replica_cache_token_map_ref = token_map
            cache = self._replica_cache
            cache[store_key] = (token, replicas)
            # Re-assigning an existing key keeps its old position, so
            # promote it explicitly.
            cache.move_to_end(store_key)
            over_capacity = len(cache) > self._cache_replicas_size
            if over_capacity:
                cache.popitem(last=False)
695+
def make_query_plan(self, working_keyspace=None, query=None):
    """
    Yield the hosts to try for ``query``, replicas of its routing key first.

    The keyspace is ``query.keyspace`` when set, otherwise
    ``working_keyspace``.  When there is no query, no routing key or no
    keyspace, the child policy's plan is yielded unchanged.

    Otherwise replicas are resolved from the tablet covering the key's
    token when one exists, and else from the replica cache or the token
    map.  Replicas that are up are yielded grouped by the child policy's
    distance (local rack, then local, then remote); replicas at any other
    distance are skipped.  The remaining hosts then come from the child's
    ``make_query_plan_with_exclusion``.  If no usable replica was found, the
    child policy's plan is yielded unchanged.

    Any error while resolving the token, tablet or replicas is logged at
    debug level and treated as "no replicas known".
    """
    keyspace = query.keyspace if query and query.keyspace else working_keyspace

    child = self._child_policy
    if query is None or query.routing_key is None or keyspace is None:
        yield from child.make_query_plan(keyspace, query)
        return

    cluster_metadata = self._cluster_metadata
    token_map = cluster_metadata.token_map
    replicas = []
    if token_map:
        try:
            token = token_map.token_class.from_key(query.routing_key)
            tablet = cluster_metadata._tablets.get_tablet_for_key(
                keyspace, query.table, token
            )

            if tablet is not None:
                replicas_mapped = {r[0] for r in tablet.replicas}
                child_plan = child.make_query_plan(keyspace, query)
                replicas = [host for host in child_plan if host.host_id in replicas_mapped]
            else:
                cached = self._get_cached_replicas(keyspace, query.routing_key, token_map)
                if cached is not None:
                    # The token was already computed above; only the
                    # replica list is needed from the cached entry.
                    _, replicas = cached
                else:
                    try:
                        replicas = token_map.get_replicas(keyspace, token)
                    except Exception:
                        log.debug(
                            "Failed to get replicas from token_map, "
                            "falling back to cluster metadata",
                            exc_info=True,
                        )
                        # Best-effort answer for this call only.  It is
                        # deliberately not cached, so a transient failure
                        # cannot pin a fallback (possibly empty) replica
                        # list until the token map is next rebuilt.
                        replicas = cluster_metadata.get_replicas(keyspace, query.routing_key)
                    else:
                        self._put_cached_replicas(
                            keyspace, query.routing_key, token, replicas, token_map
                        )
        except Exception:
            log.debug(
                "Failed to resolve token or tablet for query plan, "
                "falling back to child policy",
                exc_info=True,
            )

    if self.shuffle_replicas and not query.is_lwt():
        # Shuffle a copy: the list may be the very object held by the cache.
        replicas = list(replicas)
        shuffle(replicas)

    local_rack = []
    local = []
    remote = []

    child_distance = child.distance

    for replica in replicas:
        if replica.is_up:
            d = child_distance(replica)
            if d == HostDistance.LOCAL_RACK:
                local_rack.append(replica)
            elif d == HostDistance.LOCAL:
                local.append(replica)
            elif d == HostDistance.REMOTE:
                remote.append(replica)

    if local_rack or local or remote:
        yielded = set()

        # Replicas first: local_rack, local, remote.
        for bucket in (local_rack, local, remote):
            for replica in bucket:
                yielded.add(replica)
                yield replica

        # Yield the rest of the cluster (non-replica hosts).
        # DCAware and RackAware already yield in distance order
        # (local_rack -> local -> remote), so we can stream directly.
        # For other child policies we must re-sort by distance.
        if isinstance(child, (DCAwareRoundRobinPolicy, RackAwareRoundRobinPolicy)):
            yield from child.make_query_plan_with_exclusion(keyspace, query, yielded)
        else:
            remaining_local_rack = []
            remaining_local = []
            remaining_remote = []
            for host in child.make_query_plan_with_exclusion(keyspace, query, yielded):
                d = child_distance(host)
                if d == HostDistance.LOCAL_RACK:
                    remaining_local_rack.append(host)
                elif d == HostDistance.LOCAL:
                    remaining_local.append(host)
                elif d == HostDistance.REMOTE:
                    remaining_remote.append(host)
            yield from remaining_local_rack
            yield from remaining_local
            yield from remaining_remote
    else:
        yield from child.make_query_plan(keyspace, query)
682800
def on_up(self, *args, **kwargs):
    """Forward the call unchanged to the child policy's ``on_up``."""
    child = self._child_policy
    return child.on_up(*args, **kwargs)
0 commit comments