Skip to content

Commit 44cdd0d

Browse files
committed
Address Copilot review: bound caches, add clear API, fix benchmark imports
- Add a 256-entry size cap to both _deserializer_cache and _make_deserializers_cache to prevent unbounded growth from non-interned parameterized types in unprepared queries.
- Add a clear_deserializer_caches() public API so that runtime Des* class overrides (e.g. DesBytesType = DesBytesTypeByteArray for cqlsh) can flush stale cached instances.
- Add a get_deserializer_cache_sizes() diagnostic helper.
- Document the override/cache interaction in code comments.
- Fix the benchmark copyright (DataStax -> ScyllaDB) and add pytest.importorskip guards for pytest-benchmark and Cython.
- Add 11 unit tests for cache hit/miss, clear, eviction bounds, and size reporting.
1 parent 2a7f478 commit 44cdd0d

3 files changed

Lines changed: 231 additions & 1 deletion

File tree

benchmarks/test_deserializer_cache_benchmark.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright DataStax, Inc.
1+
# Copyright ScyllaDB, Inc.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -16,10 +16,16 @@
1616
Benchmarks for find_deserializer / make_deserializers with and without caching.
1717
1818
Run with: pytest benchmarks/test_deserializer_cache_benchmark.py -v
19+
20+
Requires the ``pytest-benchmark`` plugin and Cython extensions to be built.
21+
Skipped automatically when either dependency is unavailable.
1922
"""
2023

2124
import pytest
2225

26+
pytest.importorskip("pytest_benchmark")
27+
pytest.importorskip("cassandra.deserializers")
28+
2329
from cassandra import cqltypes
2430
from cassandra.deserializers import (
2531
find_deserializer,

cassandra/deserializers.pyx

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,14 @@ cdef class GenericDeserializer(Deserializer):
440440
#--------------------------------------------------------------------------
441441
# Helper utilities
442442

443+
# Maximum number of entries in each deserializer cache. In practice the
444+
# caches are bounded by the number of distinct column-type signatures in
445+
# the schema (typically dozens to low hundreds), but parameterized types
446+
# created via apply_parameters() for unprepared queries are *not*
447+
# interned, so repeated simple queries could accumulate entries. The cap
448+
# prevents unbounded growth in such edge cases.
449+
cdef int _CACHE_MAX_SIZE = 256
450+
443451
# Cache make_deserializers results keyed on the tuple of cqltype objects.
444452
# Using the cqltype objects themselves (rather than id()) as keys ensures
445453
# the dict holds strong references, preventing GC and id() reuse issues
@@ -454,6 +462,8 @@ def make_deserializers(cqltypes):
454462
except KeyError:
455463
pass
456464
result = obj_array([find_deserializer(ct) for ct in cqltypes])
465+
if len(_make_deserializers_cache) >= _CACHE_MAX_SIZE:
466+
_make_deserializers_cache.clear()
457467
_make_deserializers_cache[key] = result
458468
return result
459469

@@ -464,6 +474,11 @@ cdef dict classes = globals()
464474
# repeated class lookups and object creation on every result set.
465475
# Using the object as key (rather than id()) holds a strong reference,
466476
# preventing GC and id() reuse issues with parameterized types.
477+
#
478+
# Note: if a Des* class is overridden at runtime (e.g. DesBytesType =
479+
# DesBytesTypeByteArray for cqlsh), callers must invoke
480+
# clear_deserializer_caches() to flush stale entries so that subsequent
481+
# find_deserializer() calls pick up the new class.
467482
cdef dict _deserializer_cache = {}
468483

469484
cpdef Deserializer find_deserializer(cqltype):
@@ -501,10 +516,29 @@ cpdef Deserializer find_deserializer(cqltype):
501516
cls = GenericDeserializer
502517

503518
cdef Deserializer result = cls(cqltype)
519+
if len(_deserializer_cache) >= _CACHE_MAX_SIZE:
520+
_deserializer_cache.clear()
504521
_deserializer_cache[cqltype] = result
505522
return result
506523

507524

525+
def clear_deserializer_caches():
    """Empty both module-level deserializer caches.

    The caches hand back previously built deserializer instances, so a
    Des* class that is replaced at runtime (for example
    ``deserializers.DesBytesType = deserializers.DesBytesTypeByteArray``)
    is not picked up until the cached entries are dropped.  Call this
    right after such an override.
    """
    # The find_deserializer cache is emptied first, then the
    # make_deserializers cache.
    for cache in (_deserializer_cache, _make_deserializers_cache):
        cache.clear()
535+
536+
537+
def get_deserializer_cache_sizes():
    """Report how many entries each deserializer cache currently holds.

    Intended for diagnostics.  Returns a 2-tuple
    ``(find_cache_size, make_cache_size)``: the entry count of the
    find_deserializer cache followed by that of the make_deserializers
    cache.
    """
    find_cache_size = len(_deserializer_cache)
    make_cache_size = len(_make_deserializers_cache)
    return find_cache_size, make_cache_size
540+
541+
508542
def obj_array(list objs):
509543
"""Create a (Cython) array of objects given a list of objects"""
510544
cdef object[:] arr
Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
# Copyright ScyllaDB, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""
16+
Unit tests for the deserializer caches in deserializers.pyx.
17+
18+
Validates cache hit/miss behaviour, bounded eviction, the
19+
clear_deserializer_caches() API (needed after runtime Des* overrides),
20+
and the get_deserializer_cache_sizes() diagnostic helper.
21+
"""
22+
23+
import unittest
24+
25+
from tests.unit.cython.utils import cythontest
26+
27+
try:
28+
from cassandra.deserializers import (
29+
clear_deserializer_caches,
30+
find_deserializer,
31+
get_deserializer_cache_sizes,
32+
make_deserializers,
33+
)
34+
35+
_HAS_DESERIALIZERS = True
36+
except ImportError:
37+
_HAS_DESERIALIZERS = False
38+
39+
from cassandra import cqltypes
40+
41+
42+
# ---------------------------------------------------------------------------
43+
# Tests
44+
# ---------------------------------------------------------------------------
45+
46+
47+
class DeserializerCacheTest(unittest.TestCase):
    """Exercise the caches behind find_deserializer and make_deserializers."""

    def setUp(self):
        # Nothing to reset when the Cython module could not be imported.
        if not _HAS_DESERIALIZERS:
            return
        # Every test starts from empty caches.
        clear_deserializer_caches()

    def tearDown(self):
        if not _HAS_DESERIALIZERS:
            return
        # Do not leak entries created by a test into later tests.
        clear_deserializer_caches()

    @staticmethod
    def _list_type_for(index):
        """Build a new parameterized ListType for the given loop index.

        The element type cycles through five simple cqltypes.  Each
        ``ListType.apply_parameters()`` call is expected to yield a
        distinct class object, hence a distinct cache key.
        """
        element_types = (
            cqltypes.Int32Type,
            cqltypes.UTF8Type,
            cqltypes.BooleanType,
            cqltypes.DoubleType,
            cqltypes.LongType,
        )
        element = element_types[index % len(element_types)]
        return cqltypes.ListType.apply_parameters([element])

    # -- find_deserializer cache -------------------------------------------

    @cythontest
    def test_find_cache_hit_same_object(self):
        """Looking up the same cqltype twice yields the identical instance."""
        first = find_deserializer(cqltypes.Int32Type)
        second = find_deserializer(cqltypes.Int32Type)
        self.assertIs(first, second)

    @cythontest
    def test_find_cache_miss_different_types(self):
        """Distinct cqltypes map to distinct deserializer instances."""
        int_des = find_deserializer(cqltypes.Int32Type)
        text_des = find_deserializer(cqltypes.UTF8Type)
        self.assertIsNot(int_des, text_des)

    @cythontest
    def test_find_returns_correct_deserializer_class(self):
        """Int32Type resolves to the DesInt32Type deserializer class."""
        deserializer = find_deserializer(cqltypes.Int32Type)
        class_name = type(deserializer).__name__
        self.assertEqual(class_name, "DesInt32Type")

    # -- make_deserializers cache ------------------------------------------

    @cythontest
    def test_make_cache_hit_same_object(self):
        """Passing the same type list twice yields the identical array."""
        column_types = [cqltypes.Int32Type, cqltypes.UTF8Type]
        first = make_deserializers(column_types)
        second = make_deserializers(column_types)
        self.assertIs(first, second)

    @cythontest
    def test_make_cache_correct_length(self):
        """The array holds one deserializer per requested type."""
        column_types = [
            cqltypes.Int32Type,
            cqltypes.UTF8Type,
            cqltypes.BooleanType,
        ]
        deserializers = make_deserializers(column_types)
        self.assertEqual(len(deserializers), 3)

    # -- clear_deserializer_caches -----------------------------------------

    @cythontest
    def test_clear_invalidates_find_cache(self):
        """A lookup after clearing builds a fresh instance of the same class."""
        before = find_deserializer(cqltypes.Int32Type)
        clear_deserializer_caches()
        after = find_deserializer(cqltypes.Int32Type)
        # The object is rebuilt, but the deserializer class is unchanged.
        self.assertIsNot(before, after)
        self.assertEqual(type(before).__name__, type(after).__name__)

    @cythontest
    def test_clear_invalidates_make_cache(self):
        """A make_deserializers call after clearing builds a fresh array."""
        column_types = [cqltypes.Int32Type, cqltypes.UTF8Type]
        before = make_deserializers(column_types)
        clear_deserializer_caches()
        after = make_deserializers(column_types)
        self.assertIsNot(before, after)

    # -- get_deserializer_cache_sizes --------------------------------------

    @cythontest
    def test_cache_sizes_empty_after_clear(self):
        """Both sizes read zero right after the clear done in setUp."""
        find_count, make_count = get_deserializer_cache_sizes()
        self.assertEqual(find_count, 0)
        self.assertEqual(make_count, 0)

    @cythontest
    def test_cache_sizes_increment(self):
        """Reported sizes track the entries added by each lookup."""
        find_deserializer(cqltypes.Int32Type)
        find_deserializer(cqltypes.UTF8Type)
        make_deserializers([cqltypes.Int32Type, cqltypes.UTF8Type])

        find_count, make_count = get_deserializer_cache_sizes()
        self.assertEqual(find_count, 2)
        self.assertEqual(make_count, 1)

    # -- bounded eviction --------------------------------------------------

    @cythontest
    def test_find_cache_bounded_size(self):
        """The find_deserializer cache never holds more than 256 entries."""
        # Build all 300 parameterized types up front, then look each one up.
        parameterized = [self._list_type_for(i) for i in range(300)]
        for cqltype in parameterized:
            find_deserializer(cqltype)

        cached, _ = get_deserializer_cache_sizes()
        self.assertLessEqual(
            cached,
            256,
            "find_deserializer cache should be bounded to 256, got %d" % cached,
        )

    @cythontest
    def test_make_cache_bounded_size(self):
        """The make_deserializers cache never holds more than 256 entries."""
        # Each iteration creates one new parameterized type and requests
        # deserializers for a single-column signature made of it.
        for i in range(300):
            make_deserializers([self._list_type_for(i)])

        _, cached = get_deserializer_cache_sizes()
        self.assertLessEqual(
            cached,
            256,
            "make_deserializers cache should be bounded to 256, got %d" % cached,
        )

0 commit comments

Comments
 (0)