Skip to content

Commit 3f1cda2

Browse files
committed
perf: cache _replica_dict on Tablet for O(1) host/shard lookup
Build a {host_id: shard_id} dict once at Tablet construction time so that policies.py and pool.py can replace set(map(lambda ...)) and linear scans with O(1) dict operations. - Add _replica_dict to __slots__ - Build dict from the materialized tuple (not the raw replicas arg) to avoid double-consuming a one-shot iterator - Update DCAwareRoundRobinPolicy to use tablet._replica_dict keys - Update HostConnection to use tablet._replica_dict.get() for shard - Rewrite replica_contains_host_id() to use dict membership - Add 7 unit tests covering dict construction, lookup, host membership, tuple storage, and the iterator edge case
1 parent 1cdc668 commit 3f1cda2

4 files changed

Lines changed: 72 additions & 12 deletions

File tree

cassandra/policies.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -507,10 +507,10 @@ def make_query_plan(self, working_keyspace=None, query=None):
507507
keyspace, query.table, self._cluster_metadata.token_map.token_class.from_key(query.routing_key))
508508

509509
if tablet is not None:
510-
replicas_mapped = set(map(lambda r: r[0], tablet.replicas))
510+
replica_dict = tablet._replica_dict
511511
child_plan = child.make_query_plan(keyspace, query)
512512

513-
replicas = [host for host in child_plan if host.host_id in replicas_mapped]
513+
replicas = [host for host in child_plan if host.host_id in replica_dict]
514514
else:
515515
replicas = self._cluster_metadata.get_replicas(keyspace, query.routing_key)
516516

cassandra/pool.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -462,10 +462,7 @@ def _get_connection_for_routing_key(self, routing_key=None, keyspace=None, table
462462
tablet = self._session.cluster.metadata._tablets.get_tablet_for_key(keyspace, table, t)
463463

464464
if tablet is not None:
465-
for replica in tablet.replicas:
466-
if replica[0] == self.host.host_id:
467-
shard_id = replica[1]
468-
break
465+
shard_id = tablet._replica_dict.get(self.host.host_id)
469466

470467
if shard_id is None:
471468
shard_id = self.host.sharding_info.shard_id_from_token(t.value)

cassandra/tablets.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,18 @@ class Tablet(object):
1515
It stores information about each replica, its host and shard,
1616
and the token interval in the format (first_token, last_token].
1717
"""
18-
__slots__ = ('first_token', 'last_token', 'replicas')
18+
__slots__ = ('first_token', 'last_token', 'replicas', '_replica_dict')
1919

2020
def __init__(self, first_token=0, last_token=0, replicas=None):
2121
self.first_token = first_token
2222
self.last_token = last_token
23-
self.replicas = tuple(replicas) if replicas is not None else None
23+
if replicas is not None:
24+
replicas_tuple = tuple(replicas)
25+
self.replicas = replicas_tuple
26+
self._replica_dict = {r[0]: r[1] for r in replicas_tuple}
27+
else:
28+
self.replicas = None
29+
self._replica_dict = {}
2430

2531
def __str__(self):
2632
return "<Tablet: first_token=%s last_token=%s replicas=%s>" \
@@ -39,10 +45,7 @@ def from_row(first_token, last_token, replicas):
3945
return None
4046

4147
def replica_contains_host_id(self, uuid: UUID) -> bool:
42-
for replica in self.replicas:
43-
if replica[0] == uuid:
44-
return True
45-
return False
48+
return uuid in self._replica_dict
4649

4750

4851
class Tablets(object):

tests/unit/test_tablets.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import unittest
2+
from uuid import UUID
23

34
from cassandra.tablets import Tablets, Tablet
45

@@ -124,3 +125,62 @@ def __init__(self, v):
124125
# Token value 50 is not > first_token (100) of the tablet whose
125126
# last_token (200) is >= 50, so no match.
126127
self.assertIsNone(tablets.get_tablet_for_key("ks", "tb", Token(50)))
128+
129+
130+
class TabletReplicaDictTest(unittest.TestCase):
131+
"""Tests for Tablet._replica_dict cached lookup."""
132+
133+
def test_replica_dict_built_from_replicas(self):
134+
u1 = UUID('12345678-1234-5678-1234-567812345678')
135+
u2 = UUID('87654321-4321-8765-4321-876543218765')
136+
t = Tablet(0, 100, [(u1, 3), (u2, 7)])
137+
self.assertEqual(t._replica_dict, {u1: 3, u2: 7})
138+
139+
def test_replica_dict_empty_when_no_replicas(self):
140+
t = Tablet(0, 100, None)
141+
self.assertEqual(t._replica_dict, {})
142+
143+
def test_replica_dict_contains_host(self):
144+
u1 = UUID('12345678-1234-5678-1234-567812345678')
145+
u2 = UUID('87654321-4321-8765-4321-876543218765')
146+
u3 = UUID('aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee')
147+
t = Tablet(0, 100, [(u1, 3), (u2, 7)])
148+
self.assertIn(u1, t._replica_dict)
149+
self.assertIn(u2, t._replica_dict)
150+
self.assertNotIn(u3, t._replica_dict)
151+
152+
def test_replica_dict_shard_lookup(self):
153+
u1 = UUID('12345678-1234-5678-1234-567812345678')
154+
u2 = UUID('87654321-4321-8765-4321-876543218765')
155+
t = Tablet(0, 100, [(u1, 3), (u2, 7)])
156+
self.assertEqual(t._replica_dict.get(u1), 3)
157+
self.assertEqual(t._replica_dict.get(u2), 7)
158+
self.assertIsNone(t._replica_dict.get(UUID('aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee')))
159+
160+
def test_replica_contains_host_id_uses_dict(self):
161+
u1 = UUID('12345678-1234-5678-1234-567812345678')
162+
u2 = UUID('87654321-4321-8765-4321-876543218765')
163+
t = Tablet(0, 100, [(u1, 3), (u2, 7)])
164+
self.assertTrue(t.replica_contains_host_id(u1))
165+
self.assertTrue(t.replica_contains_host_id(u2))
166+
self.assertFalse(t.replica_contains_host_id(UUID('aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee')))
167+
168+
def test_replicas_stored_as_tuple(self):
169+
t = Tablet(0, 100, [("host1", 0), ("host2", 1)])
170+
self.assertIsInstance(t.replicas, tuple)
171+
172+
def test_replica_dict_from_iterator(self):
173+
"""Ensure _replica_dict is correctly built even when replicas is a
174+
one-shot iterator (generator), not a reusable list."""
175+
u1 = UUID('12345678-1234-5678-1234-567812345678')
176+
u2 = UUID('87654321-4321-8765-4321-876543218765')
177+
178+
def gen():
179+
yield (u1, 3)
180+
yield (u2, 7)
181+
182+
t = Tablet(0, 100, gen())
183+
self.assertEqual(t.replicas, ((u1, 3), (u2, 7)))
184+
self.assertEqual(t._replica_dict, {u1: 3, u2: 7})
185+
self.assertTrue(t.replica_contains_host_id(u1))
186+
self.assertTrue(t.replica_contains_host_id(u2))

0 commit comments

Comments
 (0)