Skip to content

Commit 38766f2

Browse files
authored
Add ApiData, and consolidate metaclasses into JsonSchemaData (#2753)
1 parent 20bb49e commit 38766f2

3 files changed

Lines changed: 131 additions & 28 deletions

File tree

kafka/protocol/new/api_data.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
import io
2+
import weakref
3+
4+
from kafka.util import classproperty
5+
6+
from .data_container import DataContainer, SlotsBuilder
7+
from .schemas import BaseField, StructField, load_json
8+
from .schemas.fields.codecs import Int16, Int32
9+
10+
11+
class JsonSchemaData(SlotsBuilder):
12+
def __new__(metacls, name, bases, attrs, **kw):
13+
if kw.get('init', True):
14+
json = load_json(name)
15+
if 'json_patch' in attrs:
16+
json = attrs['json_patch'].__func__(metacls, json)
17+
attrs['_json'] = json
18+
attrs['_struct'] = StructField(json)
19+
if 'doc' in json:
20+
attrs['__doc__'] = attrs.get('__doc__', '') + "\nNotes from json schema:\n" + json.get('doc')
21+
attrs['__license__'] = json.get('license')
22+
return super().__new__(metacls, name, bases, attrs, **kw)
23+
24+
def __init__(cls, name, bases, attrs, **kw):
25+
super().__init__(name, bases, attrs, **kw)
26+
if kw.get('init', True):
27+
cls._struct.set_data_class(weakref.proxy(cls))
28+
29+
30+
class ApiData(DataContainer, metaclass=JsonSchemaData, init=False):
31+
def __init_subclass__(cls, **kw):
32+
super().__init_subclass__(**kw)
33+
if kw.get('init', True):
34+
# pylint: disable=E1101
35+
assert cls._json is not None
36+
assert cls._json['type'] == 'data'
37+
cls._flexible_versions = BaseField.parse_versions(cls._json['flexibleVersions'])
38+
cls._valid_versions = BaseField.parse_versions(cls._json['validVersions'])
39+
40+
def __init__(self, *args, **kwargs):
41+
if len(args) > 0 and isinstance(args[0], int) and 'version' not in kwargs:
42+
kwargs['version'] = args[0]
43+
args = tuple(args[1:])
44+
super().__init__(*args, **kwargs)
45+
46+
@classproperty
47+
def name(cls): # pylint: disable=E0213
48+
return cls._json['name'] # pylint: disable=E1101
49+
50+
@classproperty
51+
def type(cls): # pylint: disable=E0213
52+
return cls._json['type'] # pylint: disable=E1101
53+
54+
@classproperty
55+
def json(cls): # pylint: disable=E0213
56+
return cls._json # pylint: disable=E1101
57+
58+
@classproperty
59+
def valid_versions(cls): # pylint: disable=E0213
60+
return cls._valid_versions
61+
62+
@classproperty
63+
def min_version(cls): # pylint: disable=E0213
64+
return 0
65+
66+
@classproperty
67+
def max_version(cls): # pylint: disable=E0213
68+
if cls._valid_versions is not None:
69+
return cls._valid_versions[1] # pylint: disable=E1136
70+
return None
71+
72+
@classmethod
73+
def flexible_version_q(cls, version):
74+
if cls._flexible_versions is not None:
75+
if cls._flexible_versions[0] <= version <= cls._flexible_versions[1]: # pylint: disable=E1136
76+
return True
77+
return False
78+
79+
@classproperty
80+
def header_class(cls): # pylint: disable=E0213
81+
return Int16
82+
83+
def encode_header(self, flexible=False):
84+
assert self._version is not None
85+
return self.header_class.encode(self._version)
86+
87+
@classmethod
88+
def parse_header(cls, data):
89+
return cls.header_class.decode(data) # pylint: disable-msg=no-member
90+
91+
def encode(self, version=None, header=True, framed=False):
92+
if version is not None:
93+
self._version = version
94+
elif self._version is None:
95+
raise ValueError('Version required to encode data')
96+
flexible = self.flexible_version_q(self._version)
97+
encoded = self._struct.encode(self, version=self._version, compact=flexible, tagged=flexible)
98+
if not header and not framed:
99+
return encoded
100+
bits = [encoded]
101+
if header:
102+
bits.insert(0, self.encode_header(flexible=flexible))
103+
if framed:
104+
bits.insert(0, Int32.encode(sum(map(len, bits))))
105+
return b''.join(bits)
106+
107+
@classmethod
108+
def decode(cls, data, version=None, header=True, framed=False):
109+
if not header:
110+
if version is None:
111+
raise ValueError('Version required to decode data')
112+
elif not 0 <= version <= cls.max_version:
113+
raise ValueError('Invalid version %s (max version is %s).' % (version, cls.max_version))
114+
if isinstance(data, bytes):
115+
data = io.BytesIO(data)
116+
if framed:
117+
size = Int32.decode(data)
118+
if header:
119+
decoded_version = cls.parse_header(data)
120+
if version is not None:
121+
if version > decoded_version:
122+
raise ValueError('Version mismatch: found v%d, expected v%d' % (decoded_version, version))
123+
version = min(decoded_version, cls.max_version)
124+
flexible = cls.flexible_version_q(version)
125+
return cls._struct.decode(data, version=version, compact=flexible, tagged=flexible)

kafka/protocol/new/api_header.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,9 @@
1+
from .api_data import JsonSchemaData
12
from .data_container import DataContainer, SlotsBuilder
23
from .schemas import BaseField, StructField, load_json
34

45

5-
class ApiHeaderMeta(SlotsBuilder):
6-
def __new__(metacls, name, bases, attrs, **kw):
7-
if kw.get('init', True):
8-
json = load_json(name)
9-
attrs['_json'] = json
10-
attrs['_struct'] = StructField(json)
11-
return super().__new__(metacls, name, bases, attrs, **kw)
12-
13-
14-
class ApiHeader(DataContainer, metaclass=ApiHeaderMeta, init=False):
6+
class ApiHeader(DataContainer, metaclass=JsonSchemaData, init=False):
157
__slots__ = ()
168

179
def __init_subclass__(cls, **kw):

kafka/protocol/new/api_message.py

Lines changed: 4 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import io
22
import weakref
33

4+
from .api_data import JsonSchemaData
45
from .api_header import RequestHeader, ResponseHeader, ResponseClassRegistry
5-
from .data_container import DataContainer, SlotsBuilder
6+
from .data_container import DataContainer
67
from .schemas import BaseField, StructField, load_json
78
from .schemas.fields.codecs import Int32
89

@@ -48,32 +49,17 @@ def __len__(cls):
4849
return cls._valid_versions[1] + 1
4950

5051

51-
class ApiMessageMeta(VersionSubscriptable, SlotsBuilder):
52-
def __new__(metacls, name, bases, attrs, **kw):
53-
# Pass init=False from base classes
54-
if kw.get('init', True):
55-
json = load_json(name)
56-
if 'json_patch' in attrs:
57-
json = attrs['json_patch'].__func__(metacls, json)
58-
attrs['_json'] = json
59-
attrs['_struct'] = StructField(json)
60-
attrs['__doc__'] = json.get('doc')
61-
attrs['__license__'] = json.get('license')
62-
return super().__new__(metacls, name, bases, attrs, **kw)
63-
52+
class ApiMessageData(VersionSubscriptable, JsonSchemaData):
6453
def __init__(cls, name, bases, attrs, **kw):
6554
super().__init__(name, bases, attrs, **kw)
6655
if kw.get('init', True):
6756
# Ignore min valid version on request/response schemas
6857
# We'll get the brokers supported versions via ApiVersionsRequest
6958
if cls._struct._versions[0] > 0:
7059
cls._struct._versions = (0, cls._struct._versions[1])
71-
# Configure the StructField to use our ApiMessage wrapper
72-
# and not construct a default DataContainer
73-
cls._struct.set_data_class(weakref.proxy(cls))
7460

7561

76-
class ApiMessage(DataContainer, metaclass=ApiMessageMeta, init=False):
62+
class ApiMessage(DataContainer, metaclass=ApiMessageData, init=False):
7763
__slots__ = ('_header')
7864

7965
def __init_subclass__(cls, **kw):

0 commit comments

Comments
 (0)