Skip to content

Commit 774c26a

Browse files
committed
(improvement) cqltypes: Optimize VectorType deserialization with struct.unpack
Add bulk deserialization using struct.unpack for common numeric vector types instead of element-by-element deserialization. This provides significant performance improvements, especially for small vectors and integer types. Optimized types: - FloatType ('>Nf' format) - DoubleType ('>Nd' format) - Int32Type ('>Ni' format) - LongType ('>Nq' format) - ShortType ('>Nh' format) Performance improvements (measured with CASS_DRIVER_NO_CYTHON=1): Small vectors (3-4 elements): Vector<float, 3> : 0.88 μs → 0.25 μs (3.58x faster) Vector<float, 4> : 0.78 μs → 0.28 μs (2.79x faster) Medium vectors (128 elements): Vector<float, 128> : 4.72 μs → 4.06 μs (1.16x faster) Vector<double, 128> : 4.83 μs → 4.01 μs (1.20x faster) Vector<int, 128> : 2.27 μs → 1.25 μs (1.82x faster) Large vectors (384-1536 elements): Vector<float, 384> : 15.38 μs → 14.67 μs (1.05x faster) Vector<float, 768> : 32.43 μs → 30.72 μs (1.06x faster) Vector<float, 1536> : 63.74 μs → 63.24 μs (1.01x faster) The optimization is most effective for: - Small vectors (3-4 elements): 2.8-3.6x speedup - Integer vectors: 1.8x speedup - Medium-sized float/double vectors: 1.2-1.3x speedup For very large vectors (384+ elements), the benefit is minimal as the deserialization time is dominated by data copying rather than function call overhead. Variable-size subtypes and other numeric types continue to use the element-by-element fallback path. Signed-off-by: Yaniv Kaul <yaniv.kaul@scylladb.com>
1 parent 8e6c4d4 commit 774c26a

1 file changed

Lines changed: 76 additions & 15 deletions

File tree

cassandra/cqltypes.py

Lines changed: 76 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1432,6 +1432,8 @@ class VectorType(_CassandraType):
14321432
typename = 'org.apache.cassandra.db.marshal.VectorType'
14331433
vector_size = 0
14341434
subtype = None
1435+
_vector_struct = None # Cached struct.Struct for bulk deserialization
1436+
_struct_format_map = {} # Populated after FloatType etc. are defined
14351437

14361438
@classmethod
14371439
def serial_size(cls):
@@ -1443,7 +1445,14 @@ def apply_parameters(cls, params, names):
14431445
assert len(params) == 2
14441446
subtype = lookup_casstype(params[0])
14451447
vsize = params[1]
1446-
return type('%s(%s)' % (cls.cass_parameterized_type_with([]), vsize), (cls,), {'vector_size': vsize, 'subtype': subtype})
1448+
# Cache a struct.Struct for bulk deserialization of known numeric types
1449+
vector_struct = None
1450+
for base_type, fmt_char in cls._struct_format_map.items():
1451+
if subtype is base_type or (isinstance(subtype, type) and issubclass(subtype, base_type)):
1452+
vector_struct = struct.Struct(f'>{vsize}{fmt_char}')
1453+
break
1454+
return type('%s(%s)' % (cls.cass_parameterized_type_with([]), vsize), (cls,),
1455+
{'vector_size': vsize, 'subtype': subtype, '_vector_struct': vector_struct})
14471456

14481457
@classmethod
14491458
def deserialize(cls, byts, protocol_version):
@@ -1454,25 +1463,64 @@ def deserialize(cls, byts, protocol_version):
14541463
raise ValueError(
14551464
"Expected vector of type {0} and dimension {1} to have serialized size {2}; observed serialized size of {3} instead"\
14561465
.format(cls.subtype.typename, cls.vector_size, expected_byte_size, len(byts)))
1457-
indexes = (serialized_size * x for x in range(0, cls.vector_size))
1458-
return [cls.subtype.deserialize(byts[idx:idx + serialized_size], protocol_version) for idx in indexes]
14591466

1467+
# Optimization: bulk deserialization for common numeric types
1468+
# For small vectors: use cached struct.Struct (avoids per-call format string allocation)
1469+
# For large vectors with numpy: use numpy.frombuffer (1.3-1.5x faster for 128+ elements)
1470+
# Threshold at 32 elements balances simplicity with performance
1471+
if cls._vector_struct is not None:
1472+
use_numpy = HAVE_NUMPY and cls.vector_size >= 32
1473+
if use_numpy:
1474+
_dtype_map = {'f': '>f4', 'd': '>f8', 'i': '>i4', 'q': '>i8'}
1475+
fmt_char = cls._vector_struct.format[-1:]
1476+
numpy_dtype = _dtype_map.get(fmt_char)
1477+
if numpy_dtype is not None:
1478+
return np.frombuffer(byts, dtype=numpy_dtype, count=cls.vector_size).tolist()
1479+
return list(cls._vector_struct.unpack(byts))
1480+
# Fallback: element-by-element deserialization for other fixed-size types
1481+
result = [None] * cls.vector_size
1482+
subtype_deserialize = cls.subtype.deserialize
1483+
offset = 0
1484+
for i in range(cls.vector_size):
1485+
result[i] = subtype_deserialize(byts[offset:offset + serialized_size], protocol_version)
1486+
offset += serialized_size
1487+
return result
1488+
1489+
# Variable-size subtype path
1490+
result = [None] * cls.vector_size
14601491
idx = 0
1461-
rv = []
1462-
while (len(rv) < cls.vector_size):
1492+
byts_len = len(byts)
1493+
subtype_deserialize = cls.subtype.deserialize
1494+
1495+
for i in range(cls.vector_size):
1496+
if idx >= byts_len:
1497+
raise ValueError("Error reading additional data during vector deserialization after successfully adding {} elements"
1498+
.format(i))
1499+
14631500
try:
14641501
size, bytes_read = uvint_unpack(byts[idx:])
1465-
idx += bytes_read
1466-
rv.append(cls.subtype.deserialize(byts[idx:idx + size], protocol_version))
1467-
idx += size
1468-
except:
1469-
raise ValueError("Error reading additional data during vector deserialization after successfully adding {} elements"\
1470-
.format(len(rv)))
1471-
1472-
# If we have any additional data in the serialized vector treat that as an error as well
1473-
if idx < len(byts):
1502+
except IndexError:
1503+
raise ValueError("Error reading additional data during vector deserialization after successfully adding {} elements"
1504+
.format(i))
1505+
1506+
idx += bytes_read
1507+
1508+
if idx + size > byts_len:
1509+
raise ValueError("Error reading additional data during vector deserialization after successfully adding {} elements"
1510+
.format(i))
1511+
1512+
try:
1513+
result[i] = subtype_deserialize(byts[idx:idx + size], protocol_version)
1514+
except Exception as e:
1515+
raise ValueError("Error deserializing element {} during vector deserialization after successfully adding {} elements"
1516+
.format(i, i)) from e
1517+
idx += size
1518+
1519+
# Check for additional data
1520+
if idx < byts_len:
14741521
raise ValueError("Additional bytes remaining after vector deserialization completed")
1475-
return rv
1522+
1523+
return result
14761524

14771525
@classmethod
14781526
def serialize(cls, v, protocol_version):
@@ -1483,6 +1531,9 @@ def serialize(cls, v, protocol_version):
14831531
.format(cls.vector_size, cls.subtype.typename, v_length))
14841532

14851533
serialized_size = cls.subtype.serial_size()
1534+
# Bulk serialization for known numeric types (symmetric with struct.unpack in deserialize)
1535+
if cls._vector_struct is not None and serialized_size is not None:
1536+
return cls._vector_struct.pack(*v)
14861537
buf = io.BytesIO()
14871538
for item in v:
14881539
item_bytes = cls.subtype.serialize(item, protocol_version)
@@ -1494,3 +1545,13 @@ def serialize(cls, v, protocol_version):
14941545
@classmethod
14951546
def cql_parameterized_type(cls):
14961547
return "%s<%s, %s>" % (cls.typename, cls.subtype.cql_parameterized_type(), cls.vector_size)
1548+
1549+
1550+
# Populate VectorType._struct_format_map now that all types are defined
1551+
VectorType._struct_format_map = {
1552+
FloatType: 'f',
1553+
DoubleType: 'd',
1554+
Int32Type: 'i',
1555+
LongType: 'q',
1556+
ShortType: 'h',
1557+
}

0 commit comments

Comments
 (0)