Skip to content

Commit 07b63bf

Browse files
authored
Updates to kafka.protocol.new to improve compatibility with old protocol (#2735)
1 parent 8839d52 commit 07b63bf

9 files changed

Lines changed: 326 additions & 254 deletions

File tree

kafka/protocol/new/api_message.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,26 @@ def __getitem__(cls, version):
2424
# Use [] lookups to move from primary class to "versioned" classes
2525
# which are simple wrappers around the primary class but with a _version attr
2626
if cls._class_version is not None:
27-
return cls._VERSIONS[None].__getitem__(version)
27+
return cls._VERSIONS[None][version]
28+
if cls._valid_versions is not None:
29+
if version < 0:
30+
version += 1 + cls.max_version # support negative index, e.g., [-1]
31+
if not cls.min_version <= version <= cls.max_version:
32+
raise ValueError('Invalid version! min=%d, max=%d' % (cls.min_version, cls.max_version))
2833
klass_name = cls.__name__ + '_v' + str(version)
2934
if klass_name in cls._VERSIONS:
3035
return cls._VERSIONS[klass_name]
3136
cls._VERSIONS[klass_name] = type(klass_name, tuple(cls.mro()), {'_class_version': version}, init=False)
3237
return cls._VERSIONS[klass_name]
3338

39+
def __len__(cls):
40+
# Maintain compatibility
41+
if cls._valid_versions is None:
42+
raise RuntimeError('Unable to calculate __len__ for class without valid_versions')
43+
elif cls._class_version is not None:
44+
raise TypeError('len() only supported on primary message class (not versioned)')
45+
return cls._valid_versions[1] + 1
46+
3447

3548
class ApiMessageMeta(VersionSubscriptable, SlotsBuilder):
3649
def __new__(metacls, name, bases, attrs, **kw):
@@ -58,7 +71,7 @@ def __init__(cls, name, bases, attrs, **kw):
5871

5972

6073
class ApiMessage(DataContainer, metaclass=ApiMessageMeta, init=False):
61-
__slots__ = ('_header', '_version')
74+
__slots__ = ('_header')
6275

6376
def __init_subclass__(cls, **kw):
6477
super().__init_subclass__(**kw)
@@ -72,11 +85,17 @@ def __init_subclass__(cls, **kw):
7285
ResponseClassRegistry.register_response_class(weakref.proxy(cls))
7386

7487
def __init__(self, *args, **kwargs):
75-
super().__init__(*args, **kwargs)
7688
self._header = None
7789
self._version = None
7890
if 'version' in kwargs:
7991
self.API_VERSION = kwargs['version']
92+
if len(args) > 0:
93+
untagged_fields = self._struct.untagged_fields(self.API_VERSION)
94+
if len(args) != len(untagged_fields):
95+
raise RuntimeError('Unable to init ApiMessage via positional args: unexpected len')
96+
kwargs.update({field.name: args[i] for i, field in enumerate(untagged_fields)})
97+
args = ()
98+
super().__init__(*args, **kwargs)
8099

81100
@classproperty
82101
def name(cls): # pylint: disable=E0213
@@ -171,7 +190,13 @@ def encode_header(self, flexible=False):
171190
return self._header.encode(flexible=flexible) # pylint: disable=E1120
172191

173192
@classmethod
174-
def parse_header(cls, data, flexible=False):
193+
def parse_header(cls, data, version=None):
194+
version = cls._class_version if version is None else version
195+
if version is None:
196+
raise ValueError('Version required to decode data')
197+
elif not 0 <= version <= cls.max_version:
198+
raise ValueError('Invalid version %s (max version is %s).' % (version, cls.max_version))
199+
flexible = cls.flexible_version_q(version)
175200
return cls.header_class.decode(data, flexible=flexible) # pylint: disable=E1101
176201

177202
def encode(self, version=None, header=False, framed=False):
@@ -206,15 +231,15 @@ def decode(cls, data, version=None, header=False, framed=False):
206231
else:
207232
data_class = cls
208233

209-
flexible = cls.flexible_version_q(version)
210234
if isinstance(data, bytes):
211235
data = io.BytesIO(data)
212236
if framed:
213237
size = Int32.decode(data)
214238
if header:
215-
hdr = cls.parse_header(data, flexible=flexible)
239+
hdr = cls.parse_header(data, version=version)
216240
else:
217241
hdr = None
242+
flexible = cls.flexible_version_q(version)
218243
ret = cls._struct.decode(data, version=version, compact=flexible, tagged=flexible, data_class=data_class)
219244
if hdr is not None:
220245
ret._header = hdr

kafka/protocol/new/data_container.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ def __new__(metacls, name, bases, attrs, **kw):
99

1010

1111
class DataContainer(metaclass=SlotsBuilder):
12-
__slots__ = ('tags', 'unknown_tags')
12+
__slots__ = ('tags', 'unknown_tags', '_version')
1313
_struct = None
1414

1515
def __init_subclass__(cls, **kwargs):
@@ -22,8 +22,9 @@ def __init_subclass__(cls, **kwargs):
2222
field.set_data_class(type(field.type_str, (DataContainer,), {'_struct': field}))
2323
setattr(cls, field.type_str, field.data_class)
2424

25-
def __init__(self, **field_vals):
25+
def __init__(self, version=None, **field_vals):
2626
assert self._struct is not None
27+
self._version = version
2728
self.tags = None
2829
self.unknown_tags = None
2930
for field in self._struct._fields:
@@ -83,3 +84,21 @@ def __eq__(self, other):
8384
if getattr(self, field.name) != getattr(other, field.name):
8485
return False
8586
return True
87+
88+
def __iter__(self):
89+
if self._version is None:
90+
raise RuntimeError('DataContainer Iteration not supported without _version')
91+
return iter([getattr(self, field.name) for field in self._struct.untagged_fields(self._version)])
92+
93+
def __getitem__(self, key):
94+
if self._version is None:
95+
raise RuntimeError('DataContainer subscript not supported without _version')
96+
elif isinstance(key, int):
97+
field = self._struct.untagged_fields(self._version)[key]
98+
return getattr(self, field.name)
99+
elif isinstance(key, slice):
100+
fields = self._struct.untagged_fields(self._version)
101+
start, stop, step = key.indices(len(fields))
102+
return [getattr(self, fields[i].name) for i in range(start, stop, step)]
103+
else:
104+
raise TypeError('DataContainer subscript supports int or slices only: %s' % type(key).__name__)

kafka/protocol/new/metadata/api_versions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ class ApiVersionsRequest(ApiMessage): pass
88
class ApiVersionsResponse(ApiMessage):
99
# ApiVersionsResponse header never uses flexible formats, even if body does
1010
@classmethod
11-
def parse_header(cls, data, flexible=False):
12-
return super().parse_header(data, flexible=False)
11+
def parse_header(cls, data, version=None):
12+
return cls.header_class.decode(data, flexible=False) # pylint: disable=E1101
1313

1414
def encode_header(self, flexible=False):
1515
return super().encode_header(flexible=False)

kafka/protocol/new/metadata/metadata.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,21 @@
11
from ..api_message import ApiMessage
22

3+
from kafka.util import classproperty
4+
5+
6+
class MetadataRequest(ApiMessage):
7+
@classproperty
8+
def ALL_TOPICS(cls): # pylint: disable=E0213
9+
if cls._class_version == 0: # pylint: disable=E1101
10+
return []
11+
else:
12+
return None
13+
14+
@classproperty
15+
def NO_TOPICS(cls): # pylint: disable=E0213
16+
return []
17+
318

4-
class MetadataRequest(ApiMessage): pass
519
class MetadataResponse(ApiMessage):
620
@classmethod
721
def json_patch(cls, json):

kafka/protocol/new/schemas/fields/array.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,6 @@ def decode(self, data, version=None, compact=False, tagged=False):
7373
return None
7474
return [self.array_of.decode(data, version=version, compact=compact, tagged=tagged)
7575
for _ in range(size)]
76+
77+
def __repr__(self):
78+
return 'ArrayField(%s)' % self._json

kafka/protocol/new/schemas/fields/simple.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,10 +101,12 @@ def decode(self, data, version=None, compact=False, tagged=False):
101101
assert version is not None, 'version is required to decode Field'
102102
if not self.for_version_q(version):
103103
return None
104-
print("decoding", self.name)
105104
if compact and self._type is Bytes:
106105
return CompactBytes.decode(data)
107106
elif compact and isinstance(self._type, String):
108107
return CompactString(self._type.encoding).decode(data)
109108
else:
110109
return self._type.decode(data)
110+
111+
def __repr__(self):
112+
return 'SimpleField(%s)' % self._json

kafka/protocol/new/schemas/fields/struct.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,17 +55,26 @@ def tagged_fields(self, version):
5555
if field.for_version_q(version)
5656
and field.tagged_field_q(version)])
5757

58+
def untagged_fields(self, version):
59+
return [field for field in self._fields
60+
if field.for_version_q(version)
61+
and not field.tagged_field_q(version)]
62+
5863
def encode(self, item, version=None, compact=False, tagged=False):
5964
assert version is not None, 'version required to encode StructField'
6065
if not self.for_version_q(version):
6166
return b''
62-
fields = [field for field in self._fields if field.for_version_q(version) and not field.tagged_field_q(version)]
67+
fields = self.untagged_fields(version)
6368
if isinstance(item, tuple):
6469
getter = lambda item, i, field: item[i]
6570
tags = {} if len(item) == len(fields) else item[-1]
6671
elif isinstance(item, dict):
6772
getter = lambda item, i, field: item.get(field.name) # defaults?
6873
tags = item
74+
elif isinstance(item, (str, int, float)):
75+
assert len(fields) == 1, "Encoding single value item (str/int/float) requires single field struct"
76+
getter = lambda item, i, field: item
77+
tags = {}
6978
else:
7079
getter = lambda item, i, field: getattr(item, field.name)
7180
tags = item
@@ -75,7 +84,7 @@ def encode(self, item, version=None, compact=False, tagged=False):
7584
if tagged:
7685
# TaggedFields are always compact and never include nested tagged fields
7786
encoded.append(self.tagged_fields(version).encode(tags, version=version,
78-
compact=True, tagged=False))
87+
compact=True, tagged=False))
7988
return b''.join(encoded)
8089

8190
def decode(self, data, version=None, compact=False, tagged=False, data_class=None):
@@ -91,7 +100,7 @@ def decode(self, data, version=None, compact=False, tagged=False, data_class=Non
91100
}
92101
if tagged:
93102
decoded.update(self.tagged_fields(version).decode(data, version=version, compact=True, tagged=False))
94-
return data_class(**decoded)
103+
return data_class(version=version, **decoded)
95104

96105
def __len__(self):
97106
return len(self._fields)
@@ -104,4 +113,4 @@ def __eq__(self, other):
104113
return True
105114

106115
def __repr__(self):
107-
return '%s(%s, %s)' % (self.__class__.__name__, self._name, self._fields)
116+
return 'StructField(%s)' % self._json

kafka/protocol/new/schemas/fields/struct_array.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,12 @@ def is_struct_array(self):
4040
def fields(self):
4141
return self.array_of.fields
4242

43+
def tagged_fields(self, version):
44+
return self.array_of.tagged_fields(version)
45+
46+
def untagged_fields(self, version):
47+
return self.array_of.untagged_fields(version)
48+
4349
def has_data_class(self):
4450
return self.array_of.has_data_class()
4551

@@ -52,3 +58,6 @@ def data_class(self):
5258

5359
def __call__(self, *args, **kw):
5460
return self.data_class(*args, **kw) # pylint: disable=E1102
61+
62+
def __repr__(self):
63+
return 'StructArrayField(%s)' % self._json

0 commit comments

Comments
 (0)