Skip to content

Commit d1f2750

Browse files
committed
Retain header on decode; test header+msg equality
1 parent 3e5b214 commit d1f2750

3 files changed

Lines changed: 32 additions & 17 deletions

File tree

kafka/protocol/api.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,11 @@ def API_VERSION(self):
5858
def to_object(self):
5959
return _to_object(self.SCHEMA, self)
6060

61+
@classmethod
62+
@abc.abstractmethod
63+
def is_request(cls):
64+
pass
65+
6166
@property
6267
def header(self):
6368
return self._header
@@ -88,13 +93,19 @@ def decode(cls, data, header=False, framed=False):
8893
return super().decode(data)
8994
if isinstance(data, bytes):
9095
data = BytesIO(data)
91-
ret = []
9296
if framed:
93-
ret.append(Int32.decode(data))
97+
size = Int32.decode(data)
9498
if header:
95-
ret.append(cls.parse_header(data))
96-
ret.append(super().decode(data))
97-
return tuple(ret)
99+
hdr = cls.parse_header(data)
100+
else:
101+
hdr = None
102+
ret = super().decode(data)
103+
if hdr is not None:
104+
ret._header = hdr
105+
return ret
106+
107+
def __eq__(self, other):
108+
return self._header == other._header and super().__eq__(other)
98109

99110

100111
class Request(RequestResponse):
@@ -103,6 +114,10 @@ def RESPONSE_TYPE(self):
103114
"""The Response class associated with the api request"""
104115
pass
105116

117+
@classmethod
118+
def is_request(cls):
119+
return True
120+
106121
def expect_response(self):
107122
"""Override this method if an api request does not always generate a response"""
108123
return True
@@ -120,13 +135,12 @@ def header_class(cls):
120135
else:
121136
return RequestHeader
122137

123-
def encode(self, header=False, framed=False, correlation_id=None, client_id=None, **kwargs):
124-
if header and self.header is None:
125-
self.with_header(correlation_id=correlation_id, client_id=client_id)
126-
return super().encode(header=header, framed=framed)
127-
128138

129139
class Response(RequestResponse):
140+
@classmethod
141+
def is_request(cls):
142+
return False
143+
130144
def with_header(self, correlation_id=0):
131145
if self.FLEXIBLE_VERSION:
132146
self._header = self.header_class()(correlation_id, {})
@@ -140,11 +154,6 @@ def header_class(cls):
140154
else:
141155
return ResponseHeader
142156

143-
def encode(self, header=False, framed=False, correlation_id=None, **kwargs):
144-
if header and self.header is None:
145-
self.with_header(correlation_id=correlation_id)
146-
return super().encode(header=header, framed=framed)
147-
148157

149158
def _to_object(schema, data):
150159
obj = {}

kafka/protocol/struct.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ def __hash__(self):
5454
return hash(self.encode())
5555

5656
def __eq__(self, other):
57+
if not isinstance(other, Struct):
58+
return False
5759
if self.SCHEMA != other.SCHEMA:
5860
return False
5961
for attr in self.SCHEMA.names:

test/protocol/test_api_versions.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,5 +65,9 @@
6565

6666
@pytest.mark.parametrize('msg, encoded', TEST_CASES)
6767
def test_parse(msg, encoded):
68-
assert msg.encode(correlation_id=1, client_id='_internal_client_kYVL', header=True, framed=True) == encoded
69-
assert msg.decode(encoded, header=True, framed=True)[2] == msg
68+
if msg.is_request():
69+
msg.with_header(correlation_id=1, client_id='_internal_client_kYVL')
70+
else:
71+
msg.with_header(correlation_id=1)
72+
assert msg.encode(header=True, framed=True) == encoded
73+
assert msg.decode(encoded, header=True, framed=True) == msg

0 commit comments

Comments
 (0)