Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 92 additions & 19 deletions cassandra/cqltypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@
varint_pack, varint_unpack, point_be, point_le,
vints_pack, vints_unpack, uvint_unpack, uvint_pack)
from cassandra import util
from cassandra.cython_deps import HAVE_NUMPY

if HAVE_NUMPY:
import numpy as np

Comment thread
mykaul marked this conversation as resolved.
_little_endian_flag = 1 # we always serialize LE
import ipaddress
Expand Down Expand Up @@ -1432,47 +1436,100 @@ class VectorType(_CassandraType):
# Fully-qualified Cassandra marshal class name for this type.
typename = 'org.apache.cassandra.db.marshal.VectorType'
# Defaults for the unparameterized base class. apply_parameters() creates a
# subclass per (subtype, dimension) pair and overrides vector_size, subtype
# and every cached value below except _struct_format_map on that subclass.
vector_size = 0
subtype = None
_vector_struct = None # Cached struct.Struct used for bulk serialization and deserialization; None when the subtype has no entry in _struct_format_map
_struct_format_map = {} # subtype class -> struct format char; reassigned at module level after FloatType etc. are defined
_numpy_dtype = None # Cached numpy dtype string, used by deserialize() for vectors of 32+ elements when numpy is available
_subtype_serial_size = None # Cached subtype.serial_size() (computed once in apply_parameters); None means variable-size elements
_serial_size = None # Cached serial_size() for the full vector (subtype_serial_size * vector_size), or None for variable-size subtypes

@classmethod
def serial_size(cls):
    """
    Return the serialized size of the whole vector in bytes, or None.

    The value is read from ``_serial_size``, which apply_parameters()
    precomputes as ``vector_size * subtype.serial_size()`` (or None when the
    subtype reports no fixed size). On the unparameterized base class the
    attribute keeps its default of None.
    """
    return cls._serial_size


@classmethod
def apply_parameters(cls, params, names):
    """
    Create a VectorType subclass specialised for one subtype and dimension.

    ``params`` must contain exactly two entries: the subtype (resolved with
    lookup_casstype) and the vector dimension. ``names`` is accepted but not
    used by this method.

    Besides ``vector_size`` and ``subtype``, the returned class carries
    values computed once here so serialize()/deserialize() do not recompute
    them on every call:

    * ``_vector_struct`` / ``_numpy_dtype`` -- set when the subtype is, or
      derives from, a key of ``_struct_format_map``; otherwise None.
    * ``_subtype_serial_size`` -- ``subtype.serial_size()``.
    * ``_serial_size`` -- size of the whole vector, or None when the
      subtype has no fixed size.
    """
    assert len(params) == 2
    subtype = lookup_casstype(params[0])
    vsize = params[1]

    # Cache a struct.Struct for bulk (de)serialization of known numeric types.
    # The first matching entry of _struct_format_map wins.
    vector_struct = None
    numpy_dtype = None
    for base_type, fmt_char in cls._struct_format_map.items():
        if subtype is base_type or (isinstance(subtype, type) and issubclass(subtype, base_type)):
            vector_struct = struct.Struct(f'>{vsize}{fmt_char}')
            numpy_dtype = cls._numpy_dtype_map.get(fmt_char)
            break

    # Cache subtype serial_size and full vector serial_size to avoid
    # repeated method dispatch in serialize/deserialize hot paths.
    subtype_ss = subtype.serial_size()
    vec_ss = vsize * subtype_ss if subtype_ss is not None else None

    return type('%s(%s)' % (cls.cass_parameterized_type_with([]), vsize), (cls,),
                {'vector_size': vsize, 'subtype': subtype, '_vector_struct': vector_struct,
                 '_numpy_dtype': numpy_dtype, '_subtype_serial_size': subtype_ss, '_serial_size': vec_ss})

@classmethod
def deserialize(cls, byts, protocol_version):
    """
    Decode a serialized vector into a list of ``vector_size`` elements.

    Two layouts are handled, selected by ``_subtype_serial_size``:

    * Fixed-size subtype (value is not None): elements are read back to
      back, ``_subtype_serial_size`` bytes each. When a cached
      ``_vector_struct`` exists the whole buffer is unpacked in one call
      (through numpy for vectors of 32+ elements when numpy is available);
      otherwise each slice is handed to the subtype's deserialize().
    * Variable-size subtype (value is None): every element is preceded by a
      length read with uvint_unpack.

    :param byts: the serialized vector.
    :param protocol_version: passed through to the subtype's deserialize().
    :return: a list with ``vector_size`` deserialized elements.
    :raises ValueError: if a fixed-size buffer has the wrong length, if a
        variable-size buffer ends early, if an element fails to
        deserialize, or if bytes remain after the last element.
    """
    serialized_size = cls._subtype_serial_size
    if serialized_size is not None:
        expected_byte_size = serialized_size * cls.vector_size
        if len(byts) != expected_byte_size:
            raise ValueError(
                "Expected vector of type {0} and dimension {1} to have serialized size {2}; observed serialized size of {3} instead"\
                .format(cls.subtype.typename, cls.vector_size, expected_byte_size, len(byts)))

        # Optimization: bulk deserialization for common numeric types.
        # Small vectors use the cached struct.Struct (no per-call format
        # string); vectors of 32+ elements use numpy.frombuffer when numpy
        # is available.
        if cls._vector_struct is not None:
            if HAVE_NUMPY and cls.vector_size >= 32 and cls._numpy_dtype is not None:
                return np.frombuffer(byts, dtype=cls._numpy_dtype, count=cls.vector_size).tolist()
            return list(cls._vector_struct.unpack(byts))

        # Fallback: element-by-element deserialization for other fixed-size types.
        result = [None] * cls.vector_size
        subtype_deserialize = cls.subtype.deserialize
        offset = 0
        for i in range(cls.vector_size):
            result[i] = subtype_deserialize(byts[offset:offset + serialized_size], protocol_version)
            offset += serialized_size
        return result

    # Variable-size subtype path.
    result = [None] * cls.vector_size
    idx = 0
    byts_len = len(byts)
    subtype_deserialize = cls.subtype.deserialize

    for i in range(cls.vector_size):
        # Buffer exhausted before all elements were read.
        if idx >= byts_len:
            raise ValueError("Error reading additional data during vector deserialization after successfully adding {} elements"
                             .format(i))

        try:
            size, bytes_read = uvint_unpack(byts[idx:])
        except IndexError:
            raise ValueError("Error reading additional data during vector deserialization after successfully adding {} elements"
                             .format(i))

        idx += bytes_read

        # The declared element length must fit in what is left of the buffer.
        if idx + size > byts_len:
            raise ValueError("Error reading additional data during vector deserialization after successfully adding {} elements"
                             .format(i))

        try:
            result[i] = subtype_deserialize(byts[idx:idx + size], protocol_version)
        except Exception as e:
            raise ValueError("Error deserializing element {} during vector deserialization after successfully adding {} elements"
                             .format(i, i)) from e
        idx += size

    # Any bytes left after the last element are treated as an error as well.
    if idx < byts_len:
        raise ValueError("Additional bytes remaining after vector deserialization completed")

    return result

@classmethod
def serialize(cls, v, protocol_version):
Expand All @@ -1482,7 +1539,10 @@ def serialize(cls, v, protocol_version):
"Expected sequence of size {0} for vector of type {1} and dimension {0}, observed sequence of length {2}"\
.format(cls.vector_size, cls.subtype.typename, v_length))

serialized_size = cls.subtype.serial_size()
serialized_size = cls._subtype_serial_size
# Bulk serialization for known numeric types (symmetric with struct.unpack in deserialize)
if cls._vector_struct is not None and serialized_size is not None:
return cls._vector_struct.pack(*v)
buf = io.BytesIO()
for item in v:
item_bytes = cls.subtype.serialize(item, protocol_version)
Expand All @@ -1494,3 +1554,16 @@ def serialize(cls, v, protocol_version):
@classmethod
def cql_parameterized_type(cls):
    """
    Return this type rendered as ``<typename><<subtype>, <vector_size>>``,
    where the subtype part is the subtype's own cql_parameterized_type().
    """
    element_type = cls.subtype.cql_parameterized_type()
    parts = (cls.typename, element_type, cls.vector_size)
    return "%s<%s, %s>" % parts


# Populate VectorType._struct_format_map now that all types are defined.
# Keys are subtype classes, values are struct format characters.
# VectorType.apply_parameters() walks this map in insertion order and, for the
# first key that is the subtype or one of its base classes, builds a
# struct.Struct with the format '>' + vector size + format character
# ('>' selects big-endian, standard-size packing in the struct module).
VectorType._struct_format_map = {
FloatType: 'f',
DoubleType: 'd',
Int32Type: 'i',
LongType: 'q',
ShortType: 'h',
}

# Map struct format chars to numpy dtype strings for large vector deserialization.
# apply_parameters() looks the dtype up by the format character chosen above;
# deserialize() passes it to numpy.frombuffer. Each dtype is the big-endian
# ('>') counterpart of the struct character with the same byte width.
VectorType._numpy_dtype_map = {'f': '>f4', 'd': '>f8', 'i': '>i4', 'q': '>i8', 'h': '>i2'}
Loading