Skip to content

Commit b1b0b3d

Browse files
committed
Add BytesReader to replace BytesIO in decode_message()
Introduce a lightweight BytesReader class that provides the same read() interface as io.BytesIO but operates directly on the input bytes/memoryview without internal buffering overhead.

Changes:
- Add BytesReader class with __slots__ for memory efficiency
- Replace io.BytesIO(body) with BytesReader(body) in decode_message()
- BytesReader.read() returns slices directly, converting memoryview to bytes only when necessary for compatibility

Benefits:
- Eliminates BytesIO's internal buffer allocation and management
- Reduces memory overhead for protocol message decoding
- Works seamlessly with both bytes and memoryview inputs
- Maintains full API compatibility with existing read_* functions

The BytesReader is a minimal implementation focused on the read() method needed by the protocol decoder. It avoids the overhead of io.BytesIO's full file-like interface.

Signed-off-by: Yaniv Kaul <ykaul@scylladb.com>
1 parent c59934c commit b1b0b3d

1 file changed

Lines changed: 30 additions & 1 deletion

File tree

cassandra/protocol.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,34 @@ class NotSupportedError(Exception):
5353
class InternalError(Exception):
5454
pass
5555

56+
57+
class BytesReader:
    """
    Lightweight sequential reader over an in-memory buffer.

    Offers a ``read()`` method in the style of ``io.BytesIO.read()`` but
    slices the wrapped object directly instead of first copying it into an
    internal buffer. Unlike ``io.BytesIO``, asking for more bytes than remain
    raises :class:`EOFError` instead of returning a short result.

    :param data: the buffer to read from; it must support ``len()`` and
        slicing, and its slices must be accepted by ``bytes()`` (for example
        ``bytes``, ``bytearray`` or ``memoryview``).
    """
    __slots__ = ('_data', '_pos', '_size')

    def __init__(self, data):
        self._data = data
        self._pos = 0
        # Length is captured once; assumes the buffer is not resized later.
        self._size = len(data)

    def read(self, n=-1):
        """
        Return the next `n` bytes and advance the read position.

        :param n: number of bytes to read; ``None`` or a negative value
            reads everything that remains (possibly nothing).
        :return: the requested data as a ``bytes`` object.
        :raises EOFError: if fewer than `n` bytes remain; the read position
            is left unchanged in that case.
        """
        if n is None or n < 0:
            end = self._size
        else:
            end = self._pos + n
            if end > self._size:
                raise EOFError("Cannot read past the end of the buffer")
        result = self._data[self._pos:end]
        self._pos = end
        # Slices of memoryview/bytearray are not bytes; normalise so callers
        # (unpack helpers, UUID(bytes=...)) always receive a real bytes object.
        return result if isinstance(result, bytes) else bytes(result)
82+
83+
5684
ColumnMetadata = namedtuple("ColumnMetadata", ['keyspace_name', 'table_name', 'name', 'type'])
5785

5886
HEADER_DIRECTION_TO_CLIENT = 0x80
@@ -1142,7 +1170,8 @@ def decode_message(cls, protocol_version, protocol_features, user_type_map, stre
11421170
body = decompressor(body)
11431171
flags ^= COMPRESSED_FLAG
11441172

1145-
body = io.BytesIO(body)
1173+
# Use lightweight BytesReader instead of io.BytesIO to avoid buffer copy
1174+
body = BytesReader(body)
11461175
if flags & TRACING_FLAG:
11471176
trace_id = UUID(bytes=body.read(16))
11481177
flags ^= TRACING_FLAG

0 commit comments

Comments
 (0)