Skip to content

Commit 4a7b199

Browse files
committed
Address review: bound caches, add clear API, fix benchmark imports
- Add 256-entry size cap to both _deserializer_cache and _make_deserializers_cache to prevent unbounded growth from non-interned parameterized types in unprepared queries. - Add clear_deserializer_caches() public API so that runtime Des* class overrides (e.g. DesBytesType = DesBytesTypeByteArray for cqlsh) can flush stale cached instances. - Add get_deserializer_cache_sizes() diagnostic helper. - Document override/cache interaction in code comments. - Fix benchmark copyright (DataStax -> ScyllaDB), add pytest.importorskip guards for pytest-benchmark and Cython. - Add 11 unit tests for cache hit/miss, clear, eviction bounds, and size reporting. - Add clear_deserializer_caches() calls to integration test for DesBytesType override.
1 parent 00813ef commit 4a7b199

4 files changed

Lines changed: 240 additions & 1 deletion

File tree

benchmarks/test_deserializer_cache_benchmark.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,16 @@
1616
Benchmarks for find_deserializer / make_deserializers with and without caching.
1717
1818
Run with: pytest benchmarks/test_deserializer_cache_benchmark.py -v
19+
20+
Requires the ``pytest-benchmark`` plugin and Cython extensions to be built.
21+
Skipped automatically when either dependency is unavailable.
1922
"""
2023

2124
import pytest
2225

26+
pytest.importorskip("pytest_benchmark")
27+
pytest.importorskip("cassandra.deserializers")
28+
2329
from cassandra import cqltypes
2430
from cassandra.deserializers import (
2531
find_deserializer,

cassandra/deserializers.pyx

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -440,20 +440,33 @@ cdef class GenericDeserializer(Deserializer):
440440
#--------------------------------------------------------------------------
441441
# Helper utilities
442442

443+
# Maximum number of entries in each deserializer cache. In practice the
444+
# caches are bounded by the number of distinct column-type signatures in
445+
# the schema (typically dozens to low hundreds), but parameterized types
446+
# created via apply_parameters() for unprepared queries are *not*
447+
# interned, so repeated simple queries could accumulate entries. The cap
448+
# prevents unbounded growth in such edge cases.
449+
cdef int _CACHE_MAX_SIZE = 256
450+
443451
# Cache make_deserializers results keyed on the tuple of cqltype objects.
444452
# Using the cqltype objects themselves (rather than id()) as keys ensures
445453
# the dict holds strong references, preventing GC and id() reuse issues
446454
# with non-singleton parameterized types.
447455
cdef dict _make_deserializers_cache = {}
448456

449457
def make_deserializers(cqltypes):
450-
"""Create an array of Deserializers for each given cqltype in cqltypes"""
458+
"""Create an array of Deserializers for each given cqltype in cqltypes.
459+
460+
The returned array may be a cached object shared across callers.
461+
Callers must not modify the returned array."""
451462
cdef tuple key = tuple(cqltypes)
452463
try:
453464
return _make_deserializers_cache[key]
454465
except KeyError:
455466
pass
456467
result = obj_array([find_deserializer(ct) for ct in cqltypes])
468+
if len(_make_deserializers_cache) >= _CACHE_MAX_SIZE:
469+
_make_deserializers_cache.clear()
457470
_make_deserializers_cache[key] = result
458471
return result
459472

@@ -464,6 +477,11 @@ cdef dict classes = globals()
464477
# repeated class lookups and object creation on every result set.
465478
# Using the object as key (rather than id()) holds a strong reference,
466479
# preventing GC and id() reuse issues with parameterized types.
480+
#
481+
# Note: if a Des* class is overridden at runtime (e.g. DesBytesType =
482+
# DesBytesTypeByteArray for cqlsh), callers must invoke
483+
# clear_deserializer_caches() to flush stale entries so that subsequent
484+
# find_deserializer() calls pick up the new class.
467485
cdef dict _deserializer_cache = {}
468486

469487
cpdef Deserializer find_deserializer(cqltype):
@@ -501,10 +519,29 @@ cpdef Deserializer find_deserializer(cqltype):
501519
cls = GenericDeserializer
502520

503521
cdef Deserializer result = cls(cqltype)
522+
if len(_deserializer_cache) >= _CACHE_MAX_SIZE:
523+
_deserializer_cache.clear()
504524
_deserializer_cache[cqltype] = result
505525
return result
506526

507527

528+
def clear_deserializer_caches():
529+
"""Clear the find_deserializer and make_deserializers caches.
530+
531+
Call this after overriding a Des* class at runtime (e.g.
532+
``deserializers.DesBytesType = deserializers.DesBytesTypeByteArray``)
533+
so that subsequent lookups pick up the new class instead of returning
534+
stale cached instances.
535+
"""
536+
_deserializer_cache.clear()
537+
_make_deserializers_cache.clear()
538+
539+
540+
def get_deserializer_cache_sizes():
541+
"""Return ``(find_cache_size, make_cache_size)`` for diagnostic use."""
542+
return len(_deserializer_cache), len(_make_deserializers_cache)
543+
544+
508545
def obj_array(list objs):
509546
"""Create a (Cython) array of objects given a list of objects"""
510547
cdef object[:] arr

tests/integration/standard/test_types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ def test_des_bytes_type_array(self):
108108

109109
original = cassandra.deserializers.DesBytesType
110110
cassandra.deserializers.DesBytesType = cassandra.deserializers.DesBytesTypeByteArray
111+
cassandra.deserializers.clear_deserializer_caches()
111112
s = self.session
112113

113114
s.execute("CREATE TABLE blobbytes2 (a ascii PRIMARY KEY, b blob)")
@@ -121,6 +122,7 @@ def test_des_bytes_type_array(self):
121122
finally:
122123
if original is not None:
123124
cassandra.deserializers.DesBytesType=original
125+
cassandra.deserializers.clear_deserializer_caches()
124126

125127
def test_can_insert_primitive_datatypes(self):
126128
"""
Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
# Copyright ScyllaDB, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""
16+
Unit tests for the deserializer caches in deserializers.pyx.
17+
18+
Validates cache hit/miss behaviour, bounded eviction, the
19+
clear_deserializer_caches() API (needed after runtime Des* overrides),
20+
and the get_deserializer_cache_sizes() diagnostic helper.
21+
"""
22+
23+
import unittest
24+
25+
from tests.unit.cython.utils import cythontest
26+
27+
try:
28+
from cassandra.deserializers import (
29+
clear_deserializer_caches,
30+
find_deserializer,
31+
get_deserializer_cache_sizes,
32+
make_deserializers,
33+
)
34+
35+
_HAS_DESERIALIZERS = True
36+
except ImportError:
37+
_HAS_DESERIALIZERS = False
38+
39+
from cassandra import cqltypes
40+
41+
42+
# ---------------------------------------------------------------------------
43+
# Tests
44+
# ---------------------------------------------------------------------------
45+
46+
47+
class DeserializerCacheTest(unittest.TestCase):
48+
"""Tests for find_deserializer / make_deserializers caching."""
49+
50+
def setUp(self):
51+
if _HAS_DESERIALIZERS:
52+
clear_deserializer_caches()
53+
54+
def tearDown(self):
55+
if _HAS_DESERIALIZERS:
56+
clear_deserializer_caches()
57+
58+
# -- find_deserializer cache -------------------------------------------
59+
60+
@cythontest
61+
def test_find_cache_hit_same_object(self):
62+
"""Repeated calls for the same cqltype return the same instance."""
63+
d1 = find_deserializer(cqltypes.Int32Type)
64+
d2 = find_deserializer(cqltypes.Int32Type)
65+
self.assertIs(d1, d2)
66+
67+
@cythontest
68+
def test_find_cache_miss_different_types(self):
69+
"""Different cqltypes produce different deserializer instances."""
70+
d_int = find_deserializer(cqltypes.Int32Type)
71+
d_utf = find_deserializer(cqltypes.UTF8Type)
72+
self.assertIsNot(d_int, d_utf)
73+
74+
@cythontest
75+
def test_find_returns_correct_deserializer_class(self):
76+
"""The returned deserializer class name matches the cqltype."""
77+
d = find_deserializer(cqltypes.Int32Type)
78+
self.assertEqual(type(d).__name__, "DesInt32Type")
79+
80+
# -- make_deserializers cache ------------------------------------------
81+
82+
@cythontest
83+
def test_make_cache_hit_same_object(self):
84+
"""Repeated calls with the same type list return the same array."""
85+
types = [cqltypes.Int32Type, cqltypes.UTF8Type]
86+
r1 = make_deserializers(types)
87+
r2 = make_deserializers(types)
88+
self.assertIs(r1, r2)
89+
90+
@cythontest
91+
def test_make_cache_correct_length(self):
92+
"""Returned array has the right number of entries."""
93+
types = [cqltypes.Int32Type, cqltypes.UTF8Type, cqltypes.BooleanType]
94+
result = make_deserializers(types)
95+
self.assertEqual(len(result), 3)
96+
97+
# -- clear_deserializer_caches -----------------------------------------
98+
99+
@cythontest
100+
def test_clear_invalidates_find_cache(self):
101+
"""After clearing, find_deserializer returns a new instance."""
102+
d1 = find_deserializer(cqltypes.Int32Type)
103+
clear_deserializer_caches()
104+
d2 = find_deserializer(cqltypes.Int32Type)
105+
# New instance, but same deserializer class
106+
self.assertIsNot(d1, d2)
107+
self.assertEqual(type(d1).__name__, type(d2).__name__)
108+
109+
@cythontest
110+
def test_clear_invalidates_make_cache(self):
111+
"""After clearing, make_deserializers returns a new array."""
112+
types = [cqltypes.Int32Type, cqltypes.UTF8Type]
113+
r1 = make_deserializers(types)
114+
clear_deserializer_caches()
115+
r2 = make_deserializers(types)
116+
self.assertIsNot(r1, r2)
117+
118+
# -- get_deserializer_cache_sizes --------------------------------------
119+
120+
@cythontest
121+
def test_cache_sizes_empty_after_clear(self):
122+
"""Sizes are (0, 0) immediately after clearing."""
123+
find_size, make_size = get_deserializer_cache_sizes()
124+
self.assertEqual(find_size, 0)
125+
self.assertEqual(make_size, 0)
126+
127+
@cythontest
128+
def test_cache_sizes_increment(self):
129+
"""Sizes reflect the number of cached entries."""
130+
find_deserializer(cqltypes.Int32Type)
131+
find_deserializer(cqltypes.UTF8Type)
132+
make_deserializers([cqltypes.Int32Type, cqltypes.UTF8Type])
133+
134+
find_size, make_size = get_deserializer_cache_sizes()
135+
self.assertEqual(find_size, 2)
136+
self.assertEqual(make_size, 1)
137+
138+
# -- bounded eviction --------------------------------------------------
139+
140+
@cythontest
141+
def test_find_cache_bounded_size(self):
142+
"""find_deserializer cache should not exceed 256 entries."""
143+
# Create 300 distinct cqltype objects via apply_parameters.
144+
# Each ListType.apply_parameters() call creates a fresh class
145+
# object (no interning), so all 300 are distinct cache keys
146+
# even though only 5 inner types are cycled through.
147+
inner_types = [
148+
cqltypes.Int32Type,
149+
cqltypes.UTF8Type,
150+
cqltypes.BooleanType,
151+
cqltypes.DoubleType,
152+
cqltypes.LongType,
153+
]
154+
distinct_types = []
155+
for i in range(300):
156+
# Create ListType(inner) — each apply_parameters returns a new
157+
# class object, so these are all distinct cache keys.
158+
inner = inner_types[i % len(inner_types)]
159+
ct = cqltypes.ListType.apply_parameters([inner])
160+
distinct_types.append(ct)
161+
162+
for ct in distinct_types:
163+
find_deserializer(ct)
164+
165+
find_size, _ = get_deserializer_cache_sizes()
166+
self.assertLessEqual(
167+
find_size,
168+
256,
169+
"find_deserializer cache should be bounded to 256, got %d" % find_size,
170+
)
171+
172+
@cythontest
173+
def test_make_cache_bounded_size(self):
174+
"""make_deserializers cache should not exceed 256 entries."""
175+
# Each apply_parameters() call returns a new class object (no
176+
# interning), so all 300 iterations produce distinct cache keys.
177+
inner_types = [
178+
cqltypes.Int32Type,
179+
cqltypes.UTF8Type,
180+
cqltypes.BooleanType,
181+
cqltypes.DoubleType,
182+
cqltypes.LongType,
183+
]
184+
for i in range(300):
185+
inner = inner_types[i % len(inner_types)]
186+
ct = cqltypes.ListType.apply_parameters([inner])
187+
make_deserializers([ct])
188+
189+
_, make_size = get_deserializer_cache_sizes()
190+
self.assertLessEqual(
191+
make_size,
192+
256,
193+
"make_deserializers cache should be bounded to 256, got %d" % make_size,
194+
)

0 commit comments

Comments
 (0)