Skip to content

Commit fa0a4a5

Browse files
committed
fix: address review feedback for Cython serializers
- Support __index__ protocol in int32 serialization (_coerce_int helper)
- Add empty-list guard to deserializers.obj_array() for consistency
- Move test file to tests/unit/cython/test_serializers.py, use shared cythontest
- Skip tuple() copy when value is already list or tuple (perf optimization)
- Add test_index_protocol and test_vector_index_protocol tests
- Update docstrings for obj_array() in both serializers and deserializers
1 parent 6b825ba commit fa0a4a5

3 files changed

Lines changed: 70 additions & 26 deletions

File tree

cassandra/deserializers.pyx

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -481,7 +481,15 @@ cpdef Deserializer find_deserializer(cqltype):
481481

482482

483483
def obj_array(list objs):
484-
"""Create a (Cython) array of objects given a list of objects"""
484+
"""Create a (Cython) array of objects given a list of objects.
485+
486+
Returns the plain list for empty input since ``cython_array`` does
487+
not support zero-length shapes. Callers that use
488+
``cdef Deserializer[::1]`` typed memoryviews must guard against
489+
empty input before assignment.
490+
"""
491+
if not objs:
492+
return objs
485493
cdef object[:] arr
486494
cdef Py_ssize_t i
487495
arr = cython_array(shape=(len(objs),), itemsize=sizeof(void *), format="O")

cassandra/serializers.pyx

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ from cpython.bytes cimport PyBytes_FromStringAndSize, PyBytes_AS_STRING
3434
from cython.view cimport array as cython_array
3535

3636
from cassandra import cqltypes
37+
from operator import index as _operator_index
3738

3839
cdef bint is_little_endian
3940
from cassandra.util import is_little_endian
@@ -69,13 +70,26 @@ cdef inline void _check_int32_range(object value) except *:
6970
struct.pack('>i', ...): [-2147483648, 2147483647]. The check must
7071
be done on the Python int *before* the C-level <int32_t> cast,
7172
which would silently truncate.
73+
74+
The value should already have been coerced via ``_coerce_int()``
75+
(i.e. the ``__index__`` protocol) before being passed here.
7276
"""
7377
if value > 2147483647 or value < -2147483648:
7478
raise OverflowError(
7579
"Value %r out of range for int32 "
7680
"(must be between -2147483648 and 2147483647)" % (value,))
7781

7882

83+
cdef inline object _coerce_int(object value):
84+
"""Coerce *value* to a Python ``int`` via the ``__index__`` protocol.
85+
86+
This matches ``struct.pack('>i', value)`` semantics, which accepts any
87+
object implementing ``__index__`` (e.g. numpy integer scalars). Raises
88+
``TypeError`` for objects that do not support the protocol.
89+
"""
90+
return _operator_index(value)
91+
92+
7993
# ---------------------------------------------------------------------------
8094
# Base class
8195
# ---------------------------------------------------------------------------
@@ -141,6 +155,7 @@ cdef class SerInt32Type(Serializer):
141155
"""Serialize a Python int to 4-byte big-endian signed int32."""
142156

143157
cpdef bytes serialize(self, object value, int protocol_version):
158+
value = _coerce_int(value)
144159
_check_int32_range(value)
145160
cdef int32_t val = <int32_t>value
146161
cdef char out[4]
@@ -204,10 +219,12 @@ cdef class SerVectorType(Serializer):
204219
self.type_code = 0
205220

206221
cpdef bytes serialize(self, object value, int protocol_version):
207-
# Normalize to tuple so indexing works for any iterable with __len__.
222+
# Normalize to tuple/list so indexing works for any iterable.
208223
# The Python VectorType.serialize() only requires len() + iteration,
209-
# so we must accept the same inputs.
210-
value = tuple(value)
224+
# so we must accept the same inputs. Avoid a copy if value is
225+
# already a list or tuple (common fast path for embeddings).
226+
if not isinstance(value, (list, tuple)):
227+
value = tuple(value)
211228
cdef Py_ssize_t v_length = len(value)
212229
if v_length != self.vector_size:
213230
raise ValueError(
@@ -315,7 +332,7 @@ cdef class SerVectorType(Serializer):
315332
cdef char *dst
316333

317334
for i in range(self.vector_size):
318-
item = values[i]
335+
item = _coerce_int(values[i])
319336
_check_int32_range(item)
320337
val = <int32_t>item
321338
src = <char *>&val
@@ -403,7 +420,9 @@ def obj_array(list objs):
403420
404421
Mirrors ``deserializers.obj_array()`` so both sides share the same
405422
typed-memoryview convention. Returns the plain list for empty input
406-
since ``cython_array`` does not support zero-length shapes.
423+
since ``cython_array`` does not support zero-length shapes. Callers
424+
that use ``cdef Serializer[::1]`` typed memoryviews must guard
425+
against empty input before assignment.
407426
"""
408427
if not objs:
409428
return objs
Lines changed: 37 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
import struct
2424
import unittest
2525

26+
from tests.unit.cython.utils import cythontest
27+
2628
from cassandra.cython_deps import HAVE_CYTHON
2729

2830
try:
@@ -40,27 +42,20 @@
4042
BooleanType,
4143
)
4244

43-
# Import serializers only if Cython is available (compiled .so present)
45+
# Import serializers only if Cython is available (compiled .so present).
46+
# When VERIFY_CYTHON is set (CI mode), let ImportError propagate so build
47+
# failures are not silently swallowed.
4448
if HAVE_CYTHON or VERIFY_CYTHON:
45-
try:
46-
from cassandra.serializers import (
47-
Serializer,
48-
SerFloatType,
49-
SerDoubleType,
50-
SerInt32Type,
51-
SerVectorType,
52-
GenericSerializer,
53-
find_serializer,
54-
make_serializers,
55-
)
56-
except ImportError:
57-
# .so not built — fall back so @cythontest skips gracefully
58-
HAVE_CYTHON = False
59-
VERIFY_CYTHON = False
60-
61-
cythontest = unittest.skipUnless(
62-
HAVE_CYTHON or VERIFY_CYTHON, "Cython is not available"
63-
)
49+
from cassandra.serializers import (
50+
Serializer,
51+
SerFloatType,
52+
SerDoubleType,
53+
SerInt32Type,
54+
SerVectorType,
55+
GenericSerializer,
56+
find_serializer,
57+
make_serializers,
58+
)
6459

6560
# Protocol version used in tests (value doesn't affect scalar serialization)
6661
PROTO = 4
@@ -254,6 +249,17 @@ def test_type_error_none(self):
254249
with self.assertRaises(TypeError):
255250
self.ser.serialize(None, PROTO)
256251

252+
def test_index_protocol(self):
253+
"""Objects implementing __index__ should be accepted, like struct.pack."""
254+
255+
class MyInt:
256+
def __index__(self):
257+
return 42
258+
259+
cython_bytes = self.ser.serialize(MyInt(), PROTO)
260+
python_bytes = Int32Type.serialize(42, PROTO)
261+
self.assertEqual(cython_bytes, python_bytes)
262+
257263

258264
# ---------------------------------------------------------------------------
259265
# VectorType serializer equivalence tests
@@ -364,6 +370,17 @@ def test_element_overflow_negative(self):
364370
with self.assertRaises((OverflowError, struct.error)):
365371
self.ser.serialize([1, -2147483649, 3], PROTO)
366372

373+
def test_vector_index_protocol(self):
374+
"""Objects implementing __index__ in vector elements should be accepted."""
375+
376+
class MyInt:
377+
def __index__(self):
378+
return 42
379+
380+
cython_bytes = self.ser.serialize([MyInt(), MyInt(), MyInt()], PROTO)
381+
python_bytes = self.vec_type.serialize([42, 42, 42], PROTO)
382+
self.assertEqual(cython_bytes, python_bytes)
383+
367384

368385
@cythontest
369386
class TestSerVectorTypeGenericFallback(unittest.TestCase):

0 commit comments

Comments
 (0)