Skip to content

Commit d062ec2

Browse files
committed
Add __slots__ to protocol message classes for memory optimization
Add __slots__ to _MessageType base class and all protocol message classes to reduce memory overhead. Each message instance saves approximately 280-300 bytes by eliminating the per-instance __dict__. Changes: - _MessageType: Added __slots__ with 'custom_payload' and 'tracing' - _MessageType: Added __init__ to initialize attributes with proper defaults - _DecodableMessageType: Removed duplicate 'custom_payload' from __slots__ - All message subclasses: Added super().__init__() calls for proper initialization Key attributes must be in __slots__ because: - custom_payload: Accessed by encode_message() for ALL message types (line 1127) - tracing: Set on message instances in cluster.py (line 2972) Without these in __slots__, attempting to set them raises AttributeError, causing connection failures with: 'OptionsMessage' object has no attribute 'custom_payload' Message classes covered: - Outgoing (_MessageType): StartupMessage, OptionsMessage, QueryMessage, ExecuteMessage, PrepareMessage, BatchMessage, RegisterMessage, etc. - Incoming (_DecodableMessageType): ResultMessage, EventMessage, AuthenticateMessage, SupportedMessage, etc. Signed-off-by: Yaniv Kaul <yaniv.kaul@scylladb.com>
1 parent da64595 commit d062ec2

1 file changed

Lines changed: 97 additions & 43 deletions

File tree

cassandra/protocol.py

Lines changed: 97 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -85,10 +85,11 @@ def __init__(cls, name, bases, dct):
8585

8686

8787
class _MessageType(object, metaclass=_RegisterMessageType):
88+
__slots__ = ('custom_payload', 'tracing')
8889

89-
tracing = False
90-
custom_payload = None
91-
warnings = None
90+
def __init__(self):
91+
self.custom_payload = None
92+
self.tracing = False
9293

9394
def update_custom_payload(self, other):
9495
if other:
@@ -102,6 +103,11 @@ def __repr__(self):
102103
return '<%s(%s)>' % (self.__class__.__name__, ', '.join('%s=%r' % i for i in _get_params(self)))
103104

104105

106+
class _DecodableMessageType(_MessageType):
107+
"""Base class for messages that can be decoded and receive protocol attributes"""
108+
__slots__ = ('stream_id', 'trace_id', 'warnings')
109+
110+
105111
def _get_params(message_obj):
106112
base_attrs = dir(_MessageType)
107113
return (
@@ -113,7 +119,7 @@ def _get_params(message_obj):
113119
error_classes = {}
114120

115121

116-
class ErrorMessage(_MessageType, Exception):
122+
class ErrorMessage(Exception):
117123
opcode = 0x00
118124
name = 'ERROR'
119125
summary = 'Unknown'
@@ -418,6 +424,7 @@ class StartupMessage(_MessageType):
418424
))
419425

420426
def __init__(self, cqlversion, options):
427+
super().__init__()
421428
self.cqlversion = cqlversion
422429
self.options = options
423430

@@ -427,7 +434,9 @@ def send_body(self, f, protocol_version):
427434
write_stringmap(f, optmap)
428435

429436

430-
class ReadyMessage(_MessageType):
437+
class ReadyMessage(_DecodableMessageType):
438+
__slots__ = ()
439+
431440
opcode = 0x02
432441
name = 'READY'
433442

@@ -436,11 +445,14 @@ def recv_body(cls, *args):
436445
return cls()
437446

438447

439-
class AuthenticateMessage(_MessageType):
448+
class AuthenticateMessage(_DecodableMessageType):
449+
__slots__ = ('authenticator',)
450+
440451
opcode = 0x03
441452
name = 'AUTHENTICATE'
442453

443454
def __init__(self, authenticator):
455+
super().__init__()
444456
self.authenticator = authenticator
445457

446458
@classmethod
@@ -454,6 +466,7 @@ class CredentialsMessage(_MessageType):
454466
name = 'CREDENTIALS'
455467

456468
def __init__(self, creds):
469+
super().__init__()
457470
self.creds = creds
458471

459472
def send_body(self, f, protocol_version):
@@ -468,11 +481,14 @@ def send_body(self, f, protocol_version):
468481
write_string(f, credval)
469482

470483

471-
class AuthChallengeMessage(_MessageType):
484+
class AuthChallengeMessage(_DecodableMessageType):
485+
__slots__ = ('challenge',)
486+
472487
opcode = 0x0E
473488
name = 'AUTH_CHALLENGE'
474489

475490
def __init__(self, challenge):
491+
super().__init__()
476492
self.challenge = challenge
477493

478494
@classmethod
@@ -485,17 +501,21 @@ class AuthResponseMessage(_MessageType):
485501
name = 'AUTH_RESPONSE'
486502

487503
def __init__(self, response):
504+
super().__init__()
488505
self.response = response
489506

490507
def send_body(self, f, protocol_version):
491508
write_longstring(f, self.response)
492509

493510

494-
class AuthSuccessMessage(_MessageType):
511+
class AuthSuccessMessage(_DecodableMessageType):
512+
__slots__ = ('token',)
513+
495514
opcode = 0x10
496515
name = 'AUTH_SUCCESS'
497516

498517
def __init__(self, token):
518+
super().__init__()
499519
self.token = token
500520

501521
@classmethod
@@ -511,11 +531,14 @@ def send_body(self, f, protocol_version):
511531
pass
512532

513533

514-
class SupportedMessage(_MessageType):
534+
class SupportedMessage(_DecodableMessageType):
535+
__slots__ = ('cql_versions', 'options')
536+
515537
opcode = 0x06
516538
name = 'SUPPORTED'
517539

518540
def __init__(self, cql_versions, options):
541+
super().__init__()
519542
self.cql_versions = cql_versions
520543
self.options = options
521544

@@ -541,11 +564,15 @@ def recv_body(cls, f, *args):
541564

542565

543566
class _QueryMessage(_MessageType):
544-
545-
def __init__(self, query_params, consistency_level,
546-
serial_consistency_level=None, fetch_size=None,
547-
paging_state=None, timestamp=None, skip_meta=False,
548-
continuous_paging_options=None, keyspace=None):
567+
__slots__ = ('query_params', 'consistency_level', 'serial_consistency_level',
568+
'fetch_size', 'paging_state', 'skip_meta', 'timestamp', 'keyspace')
569+
570+
def __init__(self, query_params, consistency_level, serial_consistency_level=None,
571+
fetch_size=None, paging_state=None, skip_meta=False,
572+
timestamp=None, keyspace=None, continuous_paging_options=None):
573+
super().__init__()
574+
# Note: continuous_paging_options is accepted for backward compatibility
575+
# but is not currently implemented (not stored or used)
549576
self.query_params = query_params
550577
self.consistency_level = consistency_level
551578
self.serial_consistency_level = serial_consistency_level
@@ -607,32 +634,46 @@ def _write_paging_options(self, f, paging_options, protocol_version):
607634

608635

609636
class QueryMessage(_QueryMessage):
637+
__slots__ = ('query',)
638+
610639
opcode = 0x07
611640
name = 'QUERY'
612641

613642
def __init__(self, query, consistency_level, serial_consistency_level=None,
614643
fetch_size=None, paging_state=None, timestamp=None, continuous_paging_options=None, keyspace=None):
644+
# Note: continuous_paging_options is accepted for backward compatibility
645+
# but is not currently implemented (not stored or used)
615646
self.query = query
616-
super(QueryMessage, self).__init__(None, consistency_level, serial_consistency_level, fetch_size,
617-
paging_state, timestamp, False, continuous_paging_options, keyspace)
647+
super(QueryMessage, self).__init__(query_params=None, consistency_level=consistency_level,
648+
serial_consistency_level=serial_consistency_level,
649+
fetch_size=fetch_size, paging_state=paging_state,
650+
skip_meta=False, timestamp=timestamp, keyspace=keyspace,
651+
continuous_paging_options=continuous_paging_options)
618652

619653
def send_body(self, f, protocol_version):
620654
write_longstring(f, self.query)
621655
self._write_query_params(f, protocol_version)
622656

623657

624658
class ExecuteMessage(_QueryMessage):
659+
__slots__ = ('query_id', 'result_metadata_id')
660+
625661
opcode = 0x0A
626662
name = 'EXECUTE'
627663

628664
def __init__(self, query_id, query_params, consistency_level,
629665
serial_consistency_level=None, fetch_size=None,
630666
paging_state=None, timestamp=None, skip_meta=False,
631667
continuous_paging_options=None, result_metadata_id=None):
668+
# Note: continuous_paging_options is accepted for backward compatibility
669+
# but is not currently implemented (not stored or used)
632670
self.query_id = query_id
633671
self.result_metadata_id = result_metadata_id
634-
super(ExecuteMessage, self).__init__(query_params, consistency_level, serial_consistency_level, fetch_size,
635-
paging_state, timestamp, skip_meta, continuous_paging_options)
672+
super(ExecuteMessage, self).__init__(query_params=query_params, consistency_level=consistency_level,
673+
serial_consistency_level=serial_consistency_level,
674+
fetch_size=fetch_size, paging_state=paging_state,
675+
skip_meta=skip_meta, timestamp=timestamp, keyspace=None,
676+
continuous_paging_options=continuous_paging_options)
636677

637678
def _write_query_params(self, f, protocol_version):
638679
super(ExecuteMessage, self)._write_query_params(f, protocol_version)
@@ -653,14 +694,14 @@ def send_body(self, f, protocol_version):
653694
RESULT_KIND_SCHEMA_CHANGE = 0x0005
654695

655696

656-
class ResultMessage(_MessageType):
697+
class ResultMessage(_DecodableMessageType):
698+
__slots__ = ('kind', 'result_metadata_id', 'results', 'paging_state', 'column_names', 'column_types',
699+
'parsed_rows', 'continuous_paging_seq', 'continuous_paging_last', 'new_keyspace',
700+
'column_metadata', 'query_id', 'bind_metadata', 'pk_indexes', 'schema_change_event', 'is_lwt')
701+
657702
opcode = 0x08
658703
name = 'RESULT'
659704

660-
kind = None
661-
results = None
662-
paging_state = None
663-
664705
# Names match type name in module scope. Most are imported from cassandra.cqltypes (except CUSTOM_TYPE)
665706
type_codes = _cqltypes_by_code = dict((v, globals()[k]) for k, v in type_codes.__dict__.items() if not k.startswith('_'))
666707

@@ -671,25 +712,25 @@ class ResultMessage(_MessageType):
671712
_CONTINUOUS_PAGING_LAST_FLAG = 0x80000000
672713
_METADATA_ID_FLAG = 0x0008
673714

674-
kind = None
675-
676-
# These are all the things a result message might contain. They are populated according to 'kind'
677-
column_names = None
678-
column_types = None
679-
parsed_rows = None
680-
paging_state = None
681-
continuous_paging_seq = None
682-
continuous_paging_last = None
683-
new_keyspace = None
684-
column_metadata = None
685-
query_id = None
686-
bind_metadata = None
687-
pk_indexes = None
688-
schema_change_event = None
689-
is_lwt = False
690-
691715
def __init__(self, kind):
716+
super().__init__()
692717
self.kind = kind
718+
# Initialize all slot attributes to None
719+
self.result_metadata_id = None
720+
self.results = None
721+
self.paging_state = None
722+
self.column_names = None
723+
self.column_types = None
724+
self.parsed_rows = None
725+
self.continuous_paging_seq = None
726+
self.continuous_paging_last = None
727+
self.new_keyspace = None
728+
self.column_metadata = None
729+
self.query_id = None
730+
self.bind_metadata = None
731+
self.pk_indexes = None
732+
self.schema_change_event = None
733+
self.is_lwt = None
693734

694735
def recv(self, f, protocol_version, protocol_features, user_type_map, result_metadata, column_encryption_policy):
695736
if self.kind == RESULT_KIND_VOID:
@@ -859,10 +900,13 @@ def recv_row(f, colcount):
859900

860901

861902
class PrepareMessage(_MessageType):
903+
__slots__ = ('query', 'keyspace')
904+
862905
opcode = 0x09
863906
name = 'PREPARE'
864907

865908
def __init__(self, query, keyspace=None):
909+
super().__init__()
866910
self.query = query
867911
self.keyspace = keyspace
868912

@@ -897,12 +941,15 @@ def send_body(self, f, protocol_version):
897941

898942

899943
class BatchMessage(_MessageType):
944+
__slots__ = ('batch_type', 'queries', 'consistency_level', 'serial_consistency_level',
945+
'timestamp', 'keyspace')
946+
900947
opcode = 0x0D
901948
name = 'BATCH'
902949

903950
def __init__(self, batch_type, queries, consistency_level,
904-
serial_consistency_level=None, timestamp=None,
905-
keyspace=None):
951+
serial_consistency_level=None, timestamp=None, keyspace=None):
952+
super().__init__()
906953
self.batch_type = batch_type
907954
self.queries = queries
908955
self.consistency_level = consistency_level
@@ -962,21 +1009,27 @@ def send_body(self, f, protocol_version):
9621009

9631010

9641011
class RegisterMessage(_MessageType):
1012+
__slots__ = ('event_list',)
1013+
9651014
opcode = 0x0B
9661015
name = 'REGISTER'
9671016

9681017
def __init__(self, event_list):
1018+
super().__init__()
9691019
self.event_list = event_list
9701020

9711021
def send_body(self, f, protocol_version):
9721022
write_stringlist(f, self.event_list)
9731023

9741024

975-
class EventMessage(_MessageType):
1025+
class EventMessage(_DecodableMessageType):
1026+
__slots__ = ('event_type', 'event_args')
1027+
9761028
opcode = 0x0C
9771029
name = 'EVENT'
9781030

9791031
def __init__(self, event_type, event_args):
1032+
super().__init__()
9801033
self.event_type = event_type
9811034
self.event_args = event_args
9821035

@@ -1038,6 +1091,7 @@ class RevisionType(object):
10381091
name = 'REVISE_REQUEST'
10391092

10401093
def __init__(self, op_type, op_id, next_pages=0):
1094+
super().__init__()
10411095
self.op_type = op_type
10421096
self.op_id = op_id
10431097
self.next_pages = next_pages

0 commit comments

Comments
 (0)