Skip to content

Commit ee04aa9

Browse files
committed
perf: optimize TokenAwarePolicy cache -- check before hash, add keyspace-aware invalidation
- Move the LRU cache lookup before token_class.from_key() so cache hits skip the murmur3 hash computation and Token object allocation entirely.
- Add keyspace-aware cache invalidation: track the per-keyspace replica map object identity so ALTER KEYSPACE / replication changes are detected even when the TokenMap object itself is reused (in-place rebuild).
- Remove unused 'token' from cache entries (was never read after storage).
- Add test_cache_invalidation_on_keyspace_replication_change.

TODO: The tablet path still does two full child-policy traversals per query. Metadata.get_host_by_host_id() is O(1) and could resolve tablet replicas in O(rf) instead. Deferred to minimize behavioral change.
1 parent dfc2a36 commit ee04aa9

8 files changed

Lines changed: 1048 additions & 30 deletions

File tree

.opencode/package-lock.json

Lines changed: 115 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

bench_ab.py

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Benchmark: recv_results_rows — origin/master vs new Cython metadata parser.
4+
5+
Compares the ACTUAL Cython recv_results_rows path (FastResultMessage)
6+
as it exists in the currently-built code.
7+
8+
Run with: taskset -c 0 python3 bench_ab.py
9+
"""
10+
11+
import struct
12+
import io
13+
import time
14+
import sys
15+
import uuid
16+
17+
18+
def write_short(buf, v):
    """Append ``v`` to ``buf`` as a big-endian unsigned 16-bit integer."""
    encoded = struct.pack('>H', v)
    buf.write(encoded)
20+
21+
def write_int(buf, v):
    """Append ``v`` to ``buf`` as a big-endian signed 32-bit integer."""
    encoded = struct.pack('>i', v)
    buf.write(encoded)
23+
24+
def write_string(buf, s):
    """Append ``s`` to ``buf`` as a 16-bit length prefix followed by the bytes.

    ``str`` input is encoded as UTF-8 first; any other value is written as-is,
    so it must already be a bytes-like object.
    """
    payload = s.encode('utf8') if isinstance(s, str) else s
    write_short(buf, len(payload))
    buf.write(payload)
29+
30+
def write_type(buf, type_code, subtypes=()):
    """Append a type code to ``buf``, followed recursively by its subtypes.

    Each entry of ``subtypes`` is either a bare type code or a tuple whose
    first element is the type code and whose remaining elements are that
    type's own subtypes.
    """
    write_short(buf, type_code)
    for child in subtypes:
        if not isinstance(child, tuple):
            write_type(buf, child)
            continue
        # Tuple form: (code, *nested_subtypes).
        child_code, child_subtypes = child[0], child[1:]
        write_type(buf, child_code, child_subtypes)
37+
38+
39+
# Column type codes used when building synthetic result messages.
# The values follow the CQL native protocol's type option ids.

# Scalar types
BIGINT_TYPE = 0x0002
BOOLEAN_TYPE = 0x0004
DOUBLE_TYPE = 0x0007
INT_TYPE = 0x0009
TIMESTAMP_TYPE = 0x000B
UUID_TYPE = 0x000C
VARCHAR_TYPE = 0x000D

# Collection types
LIST_TYPE = 0x0020
MAP_TYPE = 0x0021
SET_TYPE = 0x0022
49+
50+
51+
def build_rows_message(colcount, type_codes_list, nrows=0):
    """Build the body of a synthetic ROWS result and return it as bytes.

    The message carries the global-table-spec flag, ``colcount`` column
    specs named ``col_<i>`` in ``test_ks.test_cf``, and ``nrows`` rows of
    placeholder values. Column types are taken from ``type_codes_list``,
    cycling when there are more columns than entries. An entry may be a
    bare type code or a ``(code, *subtypes)`` tuple.
    """

    def column_type(index):
        # Cycle through the supplied types when colcount exceeds their number.
        return type_codes_list[index % len(type_codes_list)]

    def cell_bytes(base_code):
        # Placeholder payload for one cell of the given base type.
        if base_code == UUID_TYPE:
            return uuid.uuid4().bytes
        if base_code == VARCHAR_TYPE:
            return b'test_value'
        if base_code == INT_TYPE:
            return struct.pack('>i', 42)
        if base_code in (BIGINT_TYPE, TIMESTAMP_TYPE, DOUBLE_TYPE):
            return struct.pack('>q', 12345678)
        if base_code == BOOLEAN_TYPE:
            return b'\x01'
        if base_code in (LIST_TYPE, SET_TYPE, MAP_TYPE):
            # A collection holding zero elements.
            return struct.pack('>i', 0)
        return b'\x00\x00\x00\x00'

    out = io.BytesIO()

    # Metadata header: flags, column count, then the shared keyspace/table.
    write_int(out, 0x0001)  # GLOBAL_TABLES_SPEC
    write_int(out, colcount)
    write_string(out, 'test_ks')
    write_string(out, 'test_cf')

    # One (name, type) spec per column.
    for col in range(colcount):
        write_string(out, 'col_%d' % col)
        spec = column_type(col)
        if isinstance(spec, tuple):
            write_type(out, spec[0], spec[1:])
        else:
            write_type(out, spec)

    # Row data: every cell is a 32-bit length followed by that many bytes.
    write_int(out, nrows)
    for _ in range(nrows):
        for col in range(colcount):
            spec = column_type(col)
            base_code = spec[0] if isinstance(spec, tuple) else spec
            cell = cell_bytes(base_code)
            write_int(out, len(cell))
            out.write(cell)

    return out.getvalue()
84+
85+
86+
def build_no_metadata_message(colcount=10):
    """Build a ROWS result body that carries the NO_METADATA flag and no rows.

    Returns the encoded bytes: flags, column count, row count.
    """
    out = io.BytesIO()
    header_fields = (
        0x0004,    # NO_METADATA
        colcount,
        0,         # 0 rows
    )
    for field in header_fields:
        write_int(out, field)
    return out.getvalue()
92+
93+
94+
def bench(label, fn, iterations, warmup=1000):
    """Time ``fn`` and print and return its trimmed mean latency in nanoseconds.

    ``fn`` is first called ``warmup`` times untimed, then ``iterations``
    times with each call timed individually. The samples are sorted and the
    lowest and highest 5% (at least one sample from each end) are dropped
    before computing the mean and the coefficient of variation.

    Args:
        label: Text printed in front of the result line.
        fn: Zero-argument callable to measure.
        iterations: Number of timed calls; must be at least 1.
        warmup: Number of untimed calls made before measuring.

    Returns:
        The trimmed mean duration of one call, in nanoseconds.

    Raises:
        ValueError: If ``iterations`` is less than 1.
    """
    if iterations < 1:
        raise ValueError("iterations must be at least 1")

    for _ in range(warmup):
        fn()

    times = []
    for _ in range(iterations):
        t0 = time.perf_counter_ns()
        fn()
        t1 = time.perf_counter_ns()
        times.append(t1 - t0)

    times.sort()
    trim = max(1, len(times) // 20)
    trimmed = times[trim:-trim]
    if not trimmed:
        # With only one or two samples, trimming both ends leaves nothing;
        # use every sample instead of dividing by zero.
        trimmed = times

    mean_ns = sum(trimmed) / len(trimmed)
    var = sum((t - mean_ns) ** 2 for t in trimmed) / len(trimmed)
    cv = (var ** 0.5 / mean_ns * 100) if mean_ns else 0
    print(f" {label:50s} {mean_ns:9.0f} ns (cv {cv:4.1f}%)")
    return mean_ns
111+
112+
113+
def main():
    """Benchmark ``recv_results_rows`` over several synthetic result messages.

    Prints environment details, times each scenario with ``bench`` and
    finishes with a summary table of the mean latencies.
    """
    from cassandra.protocol import ProtocolHandler
    from cassandra.cython_deps import HAVE_CYTHON

    print(f"HAVE_CYTHON: {HAVE_CYTHON}")
    print(f"Python: {sys.version}")

    # Message class registered for opcode 0x08.
    fast_cls = ProtocolHandler.message_types_by_opcode[0x08]
    print(f"FastResultMessage: {fast_cls}")
    print()

    simple_types = [
        UUID_TYPE, VARCHAR_TYPE, INT_TYPE, BIGINT_TYPE, BOOLEAN_TYPE,
        DOUBLE_TYPE, TIMESTAMP_TYPE, VARCHAR_TYPE, INT_TYPE, UUID_TYPE,
    ]

    # (description, column count, column types, row count, timed iterations)
    scenarios = [
        ("10 cols, 0 rows", 10, simple_types, 0, 10000),
        ("3 cols, 0 rows", 3, simple_types[:3], 0, 10000),
        ("50 cols, 0 rows", 50, simple_types, 0, 5000),
        ("10 cols, 10 rows", 10, simple_types, 10, 5000),
        ("10 cols, 100 rows", 10, simple_types, 100, 2000),
        ("10 cols, 1000 rows", 10, simple_types, 1000, 500),
    ]

    results = {}
    for desc, colcount, types, nrows, iters in scenarios:
        payload = build_rows_message(colcount, types, nrows)
        print(f"--- {desc} ({len(payload)} bytes) ---")

        # The payload is bound as a default so each closure keeps its own message.
        def fn(data=payload):
            stream = io.BytesIO(data)
            message = fast_cls(2)
            message.recv_results_rows(stream, 4, {}, None, None)

        results[desc] = bench("Cython recv_results_rows", fn, iters)
        print()

    # NO_METADATA with result_metadata
    from cassandra.cqltypes import (UUIDType, VarcharType, Int32Type, LongType,
                                    BooleanType, DoubleType, DateType)
    column_cql_types = [
        UUIDType, VarcharType, Int32Type, LongType, BooleanType,
        DoubleType, DateType, VarcharType, Int32Type, UUIDType,
    ]
    result_md = [
        ('ks', 'cf', 'c%d' % idx, cql_type)
        for idx, cql_type in enumerate(column_cql_types)
    ]
    nm_data = build_no_metadata_message(10)
    print(f"--- NO_METADATA, 10 cols, 0 rows ({len(nm_data)} bytes) ---")

    def nm_fn(data=nm_data, md=result_md):
        stream = io.BytesIO(data)
        message = fast_cls(2)
        message.recv_results_rows(stream, 4, {}, md, None)

    results["NO_METADATA"] = bench("Cython recv_results_rows", nm_fn, 10000)
    print()

    banner = "=" * 60
    print(banner)
    print("SUMMARY (copy these numbers for A/B comparison)")
    print(banner)
    for name, mean_ns in results.items():
        print(f" {name:30s} {mean_ns:9.0f} ns")


if __name__ == '__main__':
    main()

bench_ab_default.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
#!/usr/bin/env python3
2+
"""A/B comparison: master vs PR for Default(DCAware) regression.
3+
4+
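Benchmarks the PR working tree first, then stashes local changes, checks out
master's cassandra/policies.py, benchmarks master, and finally restores the PR
state. Each benchmark runs in a fresh subprocess to avoid module caching.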
6+
"""
7+
import json
import statistics
import subprocess
import sys
10+
11+
# Self-contained benchmark program, executed with ``python -c`` in a fresh
# interpreter. It builds 45 hosts (5 DCs x 3 racks x 3 nodes), wraps a
# DCAwareRoundRobinPolicy in DefaultLoadBalancingPolicy, exhausts
# make_query_plan() N times per round, and prints the median ns/op over
# ITERS rounds as its only stdout line.
BENCH_CODE = '''
import time, uuid, statistics
from unittest.mock import Mock
from cassandra.policies import DCAwareRoundRobinPolicy, DefaultLoadBalancingPolicy, SimpleConvictionPolicy
from cassandra.pool import Host

class EP:
    def __init__(self, a):
        self.address = str(a)
        self._port = 9042
    def resolve(self):
        return (self.address, self._port)
    def __repr__(self):
        return f"{self.address}:{self._port}"
    def __hash__(self):
        return hash((self.address, self._port))
    def __eq__(self, o):
        return isinstance(o, EP) and self.address == o.address

hosts = []
for dc in range(5):
    for rack in range(3):
        for node in range(3):
            h = Host(EP(f"10.{dc}.{rack}.{node}"), SimpleConvictionPolicy, host_id=uuid.uuid4())
            h.set_location_info(f"dc{dc}", f"rack{rack}")
            h.set_up()
            hosts.append(h)

cluster = Mock()
cluster.metadata = Mock()
cluster.metadata.get_host = Mock(return_value=None)

child = DCAwareRoundRobinPolicy(local_dc="dc0", used_hosts_per_remote_dc=1)
policy = DefaultLoadBalancingPolicy(child)
policy.populate(cluster, hosts)

q = Mock()
q.keyspace = None
q.target_host = None

N, ITERS = 100_000, 7
times = []
for _ in range(ITERS):
    s = time.perf_counter_ns()
    for _ in range(N):
        for _ in policy.make_query_plan("ks", q):
            pass
    times.append((time.perf_counter_ns() - s) / N)

print(f"{statistics.median(times):.0f}")
'''
62+
63+
def run_bench():
    """Run ``BENCH_CODE`` in a fresh interpreter and return its result.

    The child process is started through ``taskset -c 0`` with a 120 second
    timeout, and its stdout is parsed as a float.

    Raises:
        RuntimeError: If the child exits with a non-zero status; its stderr
            is echoed to this process's stderr first.
    """
    command = ["taskset", "-c", "0", sys.executable, "-c", BENCH_CODE]
    completed = subprocess.run(
        command,
        capture_output=True,
        text=True,
        timeout=120,
    )
    if completed.returncode == 0:
        return float(completed.stdout.strip())

    print(f"STDERR: {completed.stderr}", file=sys.stderr)
    raise RuntimeError(f"Benchmark failed: {completed.stderr}")
72+
73+
def _run_series(label, runs=3):
    """Run the benchmark ``runs`` times, print each result, return the samples."""
    print(f"Running {label} version...")
    samples = []
    for i in range(runs):
        ns = run_bench()
        samples.append(ns)
        print(f" Run {i+1}: {ns:.0f} ns/op")
    return samples


def _stash_head():
    """Return the commit id of the newest stash entry, or '' if there is none."""
    probe = subprocess.run(
        ["git", "rev-parse", "-q", "--verify", "refs/stash"],
        capture_output=True, text=True,
    )
    return probe.stdout.strip()


# Run PR version 3 times
pr_results = _run_series("PR")

# Switch to master. Fail loudly if git cannot do it; otherwise the "master"
# numbers would silently be a second measurement of the PR code.
stash_before = _stash_head()
subprocess.run(["git", "stash"], capture_output=True, check=True)
# `git stash` succeeds without creating an entry when the tree is clean, so
# remember whether a new entry exists before popping anything later.
stashed = _stash_head() != stash_before
try:
    subprocess.run(
        ["git", "checkout", "origin/master", "--", "cassandra/policies.py"],
        capture_output=True, check=True,
    )
    master_results = _run_series("master")
finally:
    # Restore PR, even when the master benchmark failed. This stays
    # best-effort, as in the original script.
    subprocess.run(["git", "checkout", "pr-651", "--", "cassandra/policies.py"], capture_output=True)
    if stashed:
        subprocess.run(["git", "stash", "pop"], capture_output=True, check=False)

pr_med = statistics.median(pr_results)
master_med = statistics.median(master_results)
print(f"\nDefault(DCAware) - master: {master_med:.0f} ns/op")
print(f"Default(DCAware) - PR: {pr_med:.0f} ns/op")
print(f"Difference: {pr_med - master_med:+.0f} ns/op ({(pr_med/master_med - 1)*100:+.1f}%)")

0 commit comments

Comments
 (0)