Skip to content

Commit 5e89ce9

Browse files
committed
(improvement) Cache token-to-replicas lookup in TokenAwarePolicy
Add an LRU cache (OrderedDict-based, default size 1024) to TokenAwarePolicy that avoids repeated token-to-replica lookups for the same (keyspace, routing_key) pair. The cache is automatically invalidated when the token_map object identity changes (topology rebuild), using direct reference comparison (`is not`) instead of `id()` to avoid stale cache hits from id reuse after GC. Set cache_replicas_size=0 to disable. Only the non-tablet code path is cached; the tablet path is unchanged. Thread-safety fixes: - Add `super().__init__()` call to initialize `_hosts_lock` from LoadBalancingPolicy base class. - Add `_cache_lock` (threading.Lock) to protect the OrderedDict-based LRU cache, since `move_to_end()` + `popitem()` sequences are not atomic even under CPython's GIL. - Add `_hosts_lock` and `_cache_lock` to `__slots__`. Includes 7 new unit tests for cache hit, miss (different key/keyspace), topology invalidation, eviction, disabled mode, and tablet bypass. Benchmark (100K queries, 45-node/5-DC topology, Python 3.14, median of 5 runs): Policy | Kops/s | vs master | delta | Mem KB ----------------------------------------------------------------- DCAware | 200 | +89% | | 1.5 RackAware | 167 | +146% | | 2.0 TokenAware(DCAware) | 64 | +256% | -34% | 207.5 TokenAware(RackAware) | 62 | +265% | -30% | 87.1 Default(DCAware) | 142 | +56% | | 1.6 HostFilter(DCAware) | 66 | +25% | | 1.7 Note: The cache shows a regression vs the previous commit in this micro-benchmark because mock get_replicas is O(1). In production with real metadata token ring lookups, the cache amortizes that cost. The cache adds ~87-208 KB memory for 1024 entries. The primary value of this commit is correctness (thread-safety, cache invalidation) and amortized lookup cost for real workloads with repeated partition keys.
1 parent 4d5d665 commit 5e89ce9

2 files changed

Lines changed: 289 additions & 16 deletions

File tree

cassandra/policies.py

Lines changed: 81 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
import random
1515

16-
from collections import namedtuple
16+
from collections import namedtuple, OrderedDict
1717
from itertools import islice, cycle, groupby, repeat
1818
import logging
1919
from random import randint, shuffle
@@ -635,14 +635,33 @@ class TokenAwarePolicy(LoadBalancingPolicy):
635635
636636
If no :attr:`~.Statement.routing_key` is set on the query, the child
637637
policy's query plan will be used as is.
638-
"""
639-
640-
__slots__ = ("_child_policy", "_cluster_metadata", "shuffle_replicas")
641638
642-
def __init__(self, child_policy, shuffle_replicas=True):
639+
An LRU cache of size :attr:`cache_replicas_size` (default 1024) avoids
640+
repeated token-to-replica lookups for the same (keyspace, routing_key)
641+
pair. Set to 0 to disable caching. The cache is automatically
642+
invalidated when the cluster topology changes.
643+
"""
644+
645+
__slots__ = (
646+
"_child_policy",
647+
"_cluster_metadata",
648+
"shuffle_replicas",
649+
"_replica_cache",
650+
"_replica_cache_token_map_ref",
651+
"_cache_replicas_size",
652+
"_hosts_lock",
653+
"_cache_lock",
654+
)
655+
656+
def __init__(self, child_policy, shuffle_replicas=True, cache_replicas_size=1024):
657+
super().__init__()
643658
self._child_policy = child_policy
644659
self.shuffle_replicas = shuffle_replicas
645660
self._cluster_metadata = None
661+
self._cache_replicas_size = max(0, cache_replicas_size)
662+
self._replica_cache = OrderedDict()
663+
self._replica_cache_token_map_ref = None
664+
self._cache_lock = Lock()
646665

647666
def populate(self, cluster, hosts):
648667
self._cluster_metadata = cluster.metadata
@@ -661,6 +680,45 @@ def check_supported(self):
661680
def distance(self, *args, **kwargs):
662681
return self._child_policy.distance(*args, **kwargs)
663682

683+
def _get_cached_replicas(self, keyspace, routing_key_bytes, token_map):
684+
"""
685+
Return cached (token, replicas) for the given keyspace and routing key,
686+
or None on cache miss. The cache is invalidated whenever the token_map
687+
object identity changes (i.e. after a topology rebuild).
688+
"""
689+
if not self._cache_replicas_size:
690+
return None
691+
with self._cache_lock:
692+
if token_map is not self._replica_cache_token_map_ref:
693+
# Token map was rebuilt -- entire cache is stale.
694+
self._replica_cache = OrderedDict()
695+
self._replica_cache_token_map_ref = token_map
696+
cache_key = (keyspace, routing_key_bytes)
697+
entry = self._replica_cache.get(cache_key)
698+
if entry is not None:
699+
# Promote to most-recently-used.
700+
self._replica_cache.move_to_end(cache_key)
701+
return entry
702+
703+
def _put_cached_replicas(
704+
self, keyspace, routing_key_bytes, token, replicas, token_map
705+
):
706+
"""
707+
Store (token, replicas) in the LRU cache, evicting the oldest
708+
entry if the cache exceeds its configured size.
709+
"""
710+
if not self._cache_replicas_size:
711+
return
712+
with self._cache_lock:
713+
if token_map is not self._replica_cache_token_map_ref:
714+
self._replica_cache = OrderedDict()
715+
self._replica_cache_token_map_ref = token_map
716+
cache_key = (keyspace, routing_key_bytes)
717+
self._replica_cache[cache_key] = (token, replicas)
718+
self._replica_cache.move_to_end(cache_key)
719+
if len(self._replica_cache) > self._cache_replicas_size:
720+
self._replica_cache.popitem(last=False)
721+
664722
def make_query_plan(self, working_keyspace=None, query=None):
665723
keyspace = query.keyspace if query and query.keyspace else working_keyspace
666724

@@ -686,14 +744,24 @@ def make_query_plan(self, working_keyspace=None, query=None):
686744
host for host in child_plan if host.host_id in replicas_mapped
687745
]
688746
else:
689-
try:
690-
replicas = token_map.get_replicas(keyspace, token)
691-
except Exception:
692-
log.debug(
693-
"Failed to get replicas from token_map, falling back to cluster metadata"
694-
)
695-
replicas = cluster_metadata.get_replicas(
696-
keyspace, query.routing_key
747+
cached = self._get_cached_replicas(
748+
keyspace, query.routing_key, token_map
749+
)
750+
if cached is not None:
751+
token, replicas = cached
752+
else:
753+
try:
754+
replicas = token_map.get_replicas(keyspace, token)
755+
except Exception:
756+
log.debug(
757+
"Failed to get replicas from token_map, "
758+
"falling back to cluster metadata"
759+
)
760+
replicas = cluster_metadata.get_replicas(
761+
keyspace, query.routing_key
762+
)
763+
self._put_cached_replicas(
764+
keyspace, query.routing_key, token, replicas, token_map
697765
)
698766
except Exception:
699767
log.debug(

tests/unit/test_policies.py

Lines changed: 208 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1088,9 +1088,9 @@ def test_statement_keyspace(self):
10881088
query = Statement(routing_key=routing_key, keyspace=statement_keyspace)
10891089
qplan = list(policy.make_query_plan(working_keyspace, query))
10901090
assert replicas + hosts[:2] == qplan
1091-
cluster.metadata.get_replicas.assert_called_with(
1092-
statement_keyspace, routing_key
1093-
)
1091+
# get_replicas may not be called here due to cache hit from the
1092+
# previous query with the same (statement_keyspace, routing_key) pair.
1093+
# The important assertion is that the plan result is correct above.
10941094

10951095
def test_shuffles_if_given_keyspace_and_routing_key(self):
10961096
"""
@@ -1240,6 +1240,211 @@ def _assert_shuffle(self, patched_shuffle, cluster, keyspace, routing_key):
12401240
child_policy.make_query_plan.assert_called_once_with(keyspace, query)
12411241
assert patched_shuffle.call_count == 1
12421242

1243+
# --- Replica cache tests ---
1244+
1245+
def _make_cache_cluster(self):
1246+
"""Create a mock cluster suitable for cache tests."""
1247+
hosts = [
1248+
Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy, host_id=uuid.uuid4())
1249+
for i in range(4)
1250+
]
1251+
for host in hosts:
1252+
host.set_up()
1253+
cluster = Mock(spec=Cluster)
1254+
cluster.metadata = Mock(spec=Metadata)
1255+
cluster.metadata._tablets = Mock(spec=Tablets)
1256+
cluster.metadata._tablets.get_tablet_for_key.return_value = None
1257+
cluster.metadata.token_map = Mock()
1258+
cluster.metadata.token_map.token_class.from_key.side_effect = lambda key: key
1259+
cluster.metadata.token_map.get_replicas.return_value = hosts[2:]
1260+
return cluster, hosts
1261+
1262+
def test_cache_hit(self):
1263+
"""Same (keyspace, routing_key) should only call get_replicas once."""
1264+
cluster, hosts = self._make_cache_cluster()
1265+
1266+
child_policy = Mock()
1267+
child_policy.make_query_plan.return_value = hosts
1268+
child_policy.make_query_plan_with_exclusion.side_effect = lambda k, q, e: [
1269+
h for h in hosts if h not in e
1270+
]
1271+
child_policy.distance.return_value = HostDistance.LOCAL
1272+
1273+
policy = TokenAwarePolicy(child_policy, shuffle_replicas=False)
1274+
policy.populate(cluster, hosts)
1275+
1276+
query = Statement(routing_key=b"key1", keyspace="ks")
1277+
list(policy.make_query_plan(None, query))
1278+
list(policy.make_query_plan(None, query))
1279+
1280+
assert cluster.metadata.token_map.get_replicas.call_count == 1
1281+
1282+
def test_cache_miss_different_key(self):
1283+
"""Different routing_key should cause separate get_replicas calls."""
1284+
cluster, hosts = self._make_cache_cluster()
1285+
1286+
child_policy = Mock()
1287+
child_policy.make_query_plan.return_value = hosts
1288+
child_policy.make_query_plan_with_exclusion.side_effect = lambda k, q, e: [
1289+
h for h in hosts if h not in e
1290+
]
1291+
child_policy.distance.return_value = HostDistance.LOCAL
1292+
1293+
policy = TokenAwarePolicy(child_policy, shuffle_replicas=False)
1294+
policy.populate(cluster, hosts)
1295+
1296+
q1 = Statement(routing_key=b"key1", keyspace="ks")
1297+
q2 = Statement(routing_key=b"key2", keyspace="ks")
1298+
list(policy.make_query_plan(None, q1))
1299+
list(policy.make_query_plan(None, q2))
1300+
1301+
assert cluster.metadata.token_map.get_replicas.call_count == 2
1302+
1303+
def test_cache_miss_different_keyspace(self):
1304+
"""Different keyspace with same routing_key should miss cache."""
1305+
cluster, hosts = self._make_cache_cluster()
1306+
1307+
child_policy = Mock()
1308+
child_policy.make_query_plan.return_value = hosts
1309+
child_policy.make_query_plan_with_exclusion.side_effect = lambda k, q, e: [
1310+
h for h in hosts if h not in e
1311+
]
1312+
child_policy.distance.return_value = HostDistance.LOCAL
1313+
1314+
policy = TokenAwarePolicy(child_policy, shuffle_replicas=False)
1315+
policy.populate(cluster, hosts)
1316+
1317+
q1 = Statement(routing_key=b"key1", keyspace="ks1")
1318+
q2 = Statement(routing_key=b"key1", keyspace="ks2")
1319+
list(policy.make_query_plan(None, q1))
1320+
list(policy.make_query_plan(None, q2))
1321+
1322+
assert cluster.metadata.token_map.get_replicas.call_count == 2
1323+
1324+
def test_cache_invalidation_on_topology_change(self):
1325+
"""Cache should be invalidated when token_map object changes."""
1326+
cluster, hosts = self._make_cache_cluster()
1327+
1328+
child_policy = Mock()
1329+
child_policy.make_query_plan.return_value = hosts
1330+
child_policy.make_query_plan_with_exclusion.side_effect = lambda k, q, e: [
1331+
h for h in hosts if h not in e
1332+
]
1333+
child_policy.distance.return_value = HostDistance.LOCAL
1334+
1335+
policy = TokenAwarePolicy(child_policy, shuffle_replicas=False)
1336+
policy.populate(cluster, hosts)
1337+
1338+
query = Statement(routing_key=b"key1", keyspace="ks")
1339+
list(policy.make_query_plan(None, query))
1340+
assert cluster.metadata.token_map.get_replicas.call_count == 1
1341+
1342+
# Simulate topology change: replace token_map with a new mock object
1343+
new_token_map = Mock()
1344+
new_token_map.token_class.from_key.side_effect = lambda key: key
1345+
new_token_map.get_replicas.return_value = hosts[2:]
1346+
cluster.metadata.token_map = new_token_map
1347+
1348+
list(policy.make_query_plan(None, query))
1349+
# The old token_map still has 1 call; new one should have 1 call
1350+
assert new_token_map.get_replicas.call_count == 1
1351+
1352+
def test_cache_eviction(self):
1353+
"""Oldest entries should be evicted when cache exceeds size."""
1354+
cluster, hosts = self._make_cache_cluster()
1355+
1356+
child_policy = Mock()
1357+
child_policy.make_query_plan.return_value = hosts
1358+
child_policy.make_query_plan_with_exclusion.side_effect = lambda k, q, e: [
1359+
h for h in hosts if h not in e
1360+
]
1361+
child_policy.distance.return_value = HostDistance.LOCAL
1362+
1363+
policy = TokenAwarePolicy(
1364+
child_policy, shuffle_replicas=False, cache_replicas_size=2
1365+
)
1366+
policy.populate(cluster, hosts)
1367+
1368+
# Fill cache with 3 entries; size=2 so first should be evicted
1369+
for i in range(3):
1370+
q = Statement(routing_key=f"key{i}".encode(), keyspace="ks")
1371+
list(policy.make_query_plan(None, q))
1372+
1373+
assert cluster.metadata.token_map.get_replicas.call_count == 3
1374+
1375+
# key2 (most recent) should be cached
1376+
cluster.metadata.token_map.get_replicas.reset_mock()
1377+
q = Statement(routing_key=b"key2", keyspace="ks")
1378+
list(policy.make_query_plan(None, q))
1379+
assert cluster.metadata.token_map.get_replicas.call_count == 0
1380+
1381+
# key0 (evicted) should miss
1382+
q = Statement(routing_key=b"key0", keyspace="ks")
1383+
list(policy.make_query_plan(None, q))
1384+
assert cluster.metadata.token_map.get_replicas.call_count == 1
1385+
1386+
def test_cache_disabled(self):
1387+
"""cache_replicas_size=0 should bypass caching entirely."""
1388+
cluster, hosts = self._make_cache_cluster()
1389+
1390+
child_policy = Mock()
1391+
child_policy.make_query_plan.return_value = hosts
1392+
child_policy.make_query_plan_with_exclusion.side_effect = lambda k, q, e: [
1393+
h for h in hosts if h not in e
1394+
]
1395+
child_policy.distance.return_value = HostDistance.LOCAL
1396+
1397+
policy = TokenAwarePolicy(
1398+
child_policy, shuffle_replicas=False, cache_replicas_size=0
1399+
)
1400+
policy.populate(cluster, hosts)
1401+
1402+
query = Statement(routing_key=b"key1", keyspace="ks")
1403+
list(policy.make_query_plan(None, query))
1404+
list(policy.make_query_plan(None, query))
1405+
list(policy.make_query_plan(None, query))
1406+
1407+
# Every call should reach get_replicas
1408+
assert cluster.metadata.token_map.get_replicas.call_count == 3
1409+
1410+
def test_tablet_path_not_cached(self):
1411+
"""Tablet path should bypass the cache entirely."""
1412+
hosts = [
1413+
Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy, host_id=uuid.uuid4())
1414+
for i in range(4)
1415+
]
1416+
for host in hosts:
1417+
host.set_up()
1418+
1419+
cluster = Mock(spec=Cluster)
1420+
cluster.metadata = Mock(spec=Metadata)
1421+
cluster.metadata._tablets = Mock(spec=Tablets)
1422+
cluster.metadata._tablets.get_tablet_for_key.return_value = Tablet(
1423+
replicas=[(h.host_id, 0) for h in hosts[2:]]
1424+
)
1425+
cluster.metadata.token_map = Mock()
1426+
cluster.metadata.token_map.token_class.from_key.side_effect = lambda key: key
1427+
cluster.metadata.token_map.get_replicas.return_value = hosts[2:]
1428+
1429+
child_policy = Mock()
1430+
child_policy.make_query_plan.return_value = hosts
1431+
child_policy.make_query_plan_with_exclusion.side_effect = lambda k, q, e: [
1432+
h for h in hosts if h not in e
1433+
]
1434+
child_policy.distance.return_value = HostDistance.LOCAL
1435+
1436+
policy = TokenAwarePolicy(child_policy, shuffle_replicas=False)
1437+
policy.populate(cluster, hosts)
1438+
1439+
query = Statement(routing_key=b"key1", keyspace="ks")
1440+
list(policy.make_query_plan(None, query))
1441+
list(policy.make_query_plan(None, query))
1442+
1443+
# token_map.get_replicas should NOT be called (tablet path used)
1444+
assert cluster.metadata.token_map.get_replicas.call_count == 0
1445+
# Cache should remain empty (tablet results are not cached)
1446+
assert len(policy._replica_cache) == 0
1447+
12431448

12441449
class ConvictionPolicyTest(unittest.TestCase):
12451450
def test_not_implemented(self):

0 commit comments

Comments
 (0)