Skip to content

Commit 91cd1ef

Browse files
committed
fix: address review feedback for Cython serializers
- Guard find_serializer() against the un-parameterized VectorType base class, which would crash _is_float_type(None) with a TypeError
- Fix assertAlmostEqual(inf, inf) test bug (inf - inf = nan fails)
- Wrap the test serializer import in try/except for a graceful skip when the .so is missing under VERIFY_CYTHON
- Change make_serializers() to return obj_array (Cython typed memoryview), matching the make_deserializers() convention
- Eliminate double values[i] indexing in the _serialize_float vector loop
- Remove unused ctypes import in test_flt_max
1 parent df1fa44 commit 91cd1ef

2 files changed

Lines changed: 60 additions & 21 deletions

File tree

cassandra/serializers.pyx

Lines changed: 33 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -31,6 +31,7 @@ from libc.string cimport memcpy
3131
from libc.math cimport isinf, isnan
3232
from libc.float cimport FLT_MAX
3333
from cpython.bytes cimport PyBytes_FromStringAndSize, PyBytes_AS_STRING
34+
from cython.view cimport array as cython_array
3435

3536
from cassandra import cqltypes
3637

@@ -229,13 +230,15 @@ cdef class SerVectorType(Serializer):
229230

230231
cdef object result = PyBytes_FromStringAndSize(NULL, buf_size)
231232
cdef char *buf = PyBytes_AS_STRING(result)
233+
cdef double dval
232234
cdef float val
233235
cdef char *src
234236
cdef char *dst
235237

236238
for i in range(self.vector_size):
237-
_check_float_range(<double>values[i])
238-
val = <float>values[i]
239+
dval = <double>values[i]
240+
_check_float_range(dval)
241+
val = <float>dval
239242
src = <char *>&val
240243
dst = buf + i * 4
241244

@@ -361,9 +364,12 @@ cdef dict _ser_classes = {}
361364
cpdef Serializer find_serializer(cqltype):
362365
"""Find a serializer for a cqltype."""
363366

364-
# For VectorType, always use SerVectorType (it handles generic subtypes internally)
367+
# For VectorType, use SerVectorType only if parameterized (has a valid subtype).
368+
# Un-parameterized VectorType (base class) would crash _is_float_type() etc.
365369
if issubclass(cqltype, cqltypes.VectorType):
366-
return SerVectorType(cqltype)
370+
if getattr(cqltype, 'subtype', None) is not None:
371+
return SerVectorType(cqltype)
372+
return GenericSerializer(cqltype)
367373

368374
# For scalar types with dedicated serializers, look up by name
369375
name = 'Ser' + cqltype.__name__
@@ -376,8 +382,29 @@ cpdef Serializer find_serializer(cqltype):
376382

377383

378384
def make_serializers(cqltypes_list):
379-
"""Create a list of Serializer objects for each given cqltype."""
380-
return [find_serializer(ct) for ct in cqltypes_list]
385+
"""Create a Cython typed array of Serializer objects for each given cqltype.
386+
387+
Returns an ``obj_array`` (Cython typed memoryview) matching the
388+
``make_deserializers()`` convention for O(1) C-level indexed access.
389+
"""
390+
return obj_array([find_serializer(ct) for ct in cqltypes_list])
391+
392+
393+
def obj_array(list objs):
394+
"""Create a (Cython) array of objects given a list of objects.
395+
396+
Mirrors ``deserializers.obj_array()`` so both sides share the same
397+
typed-memoryview convention. Returns the plain list for empty input
398+
since ``cython_array`` does not support zero-length shapes.
399+
"""
400+
if not objs:
401+
return objs
402+
cdef object[:] arr
403+
cdef Py_ssize_t i
404+
arr = cython_array(shape=(len(objs),), itemsize=sizeof(void *), format="O")
405+
for i, obj in enumerate(objs):
406+
arr[i] = obj
407+
return arr
381408

382409

383410
# Build the lookup dict for scalar serializers at module load time

tests/unit/test_serializers.py

Lines changed: 27 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@
1919
implementations, plus correct error behavior for edge cases.
2020
"""
2121

22+
import math
2223
import struct
2324
import unittest
2425

@@ -39,18 +40,23 @@
3940
BooleanType,
4041
)
4142

42-
# Import serializers only if Cython is available
43+
# Import serializers only if Cython is available (compiled .so present)
4344
if HAVE_CYTHON or VERIFY_CYTHON:
44-
from cassandra.serializers import (
45-
Serializer,
46-
SerFloatType,
47-
SerDoubleType,
48-
SerInt32Type,
49-
SerVectorType,
50-
GenericSerializer,
51-
find_serializer,
52-
make_serializers,
53-
)
45+
try:
46+
from cassandra.serializers import (
47+
Serializer,
48+
SerFloatType,
49+
SerDoubleType,
50+
SerInt32Type,
51+
SerVectorType,
52+
GenericSerializer,
53+
find_serializer,
54+
make_serializers,
55+
)
56+
except ImportError:
57+
# .so not built — fall back so @cythontest skips gracefully
58+
HAVE_CYTHON = False
59+
VERIFY_CYTHON = False
5460

5561
cythontest = unittest.skipUnless(
5662
HAVE_CYTHON or VERIFY_CYTHON, "Cython is not available"
@@ -97,8 +103,6 @@ def test_negative_values(self):
97103
self._assert_equiv(val)
98104

99105
def test_flt_max(self):
100-
import ctypes
101-
102106
flt_max = 3.4028234663852886e38
103107
self._assert_equiv(flt_max)
104108
self._assert_equiv(-flt_max)
@@ -410,7 +414,10 @@ def test_float_round_trip(self):
410414
for val in [0.0, 1.0, -1.0, 3.14, float("inf"), float("-inf")]:
411415
serialized = ser.serialize(val, PROTO)
412416
deserialized = FloatType.deserialize(serialized, PROTO)
413-
self.assertAlmostEqual(val, deserialized, places=5)
417+
if math.isinf(val):
418+
self.assertEqual(val, deserialized)
419+
else:
420+
self.assertAlmostEqual(val, deserialized, places=5)
414421

415422
def test_double_round_trip(self):
416423
ser = SerDoubleType(DoubleType)
@@ -481,6 +488,11 @@ def test_generic_delegates_to_python(self):
481488
expected = LongType.serialize(42, PROTO)
482489
self.assertEqual(result, expected)
483490

491+
def test_unparameterized_vector_type_gets_generic(self):
492+
"""Un-parameterized VectorType (base class) should not crash."""
493+
ser = find_serializer(VectorType)
494+
self.assertIsInstance(ser, GenericSerializer)
495+
484496

485497
@cythontest
486498
class TestMakeSerializers(unittest.TestCase):
@@ -497,7 +509,7 @@ def test_basic(self):
497509

498510
def test_empty(self):
499511
serializers = make_serializers([])
500-
self.assertEqual(serializers, [])
512+
self.assertEqual(len(serializers), 0)
501513

502514
def test_with_vector_type(self):
503515
vec_type = _make_vector_type(FloatType, 3)

0 commit comments

Comments
 (0)