Skip to content

Commit 2516b7b

Browse files
committed
Added support for streaming multipart decoding
1 parent 31e8a16 commit 2516b7b

4 files changed

Lines changed: 279 additions & 8 deletions

File tree

requests_toolbelt/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
from .auth.guess import GuessAuth
1414
from .multipart import (
1515
MultipartEncoder, MultipartEncoderMonitor, MultipartDecoder,
16-
ImproperBodyPartContentException, NonMultipartContentTypeException
16+
ImproperBodyPartContentException, NonMultipartContentTypeException,
17+
MultipartStreamDecoder
1718
)
1819
from .streaming_iterator import StreamingIterator
1920
from .utils.user_agent import user_agent
@@ -31,4 +32,5 @@
3132
'StreamingIterator', 'user_agent', 'ImproperBodyPartContentException',
3233
'NonMultipartContentTypeException', '__title__', '__authors__',
3334
'__license__', '__copyright__', '__version__', '__version_info__',
35+
'MultipartStreamDecoder'
3436
]

requests_toolbelt/multipart/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
"""
1010

1111
from .encoder import MultipartEncoder, MultipartEncoderMonitor
12-
from .decoder import MultipartDecoder
12+
from .decoder import MultipartDecoder, MultipartStreamDecoder
1313
from .decoder import ImproperBodyPartContentException
1414
from .decoder import NonMultipartContentTypeException
1515

@@ -22,6 +22,7 @@
2222
'MultipartEncoder',
2323
'MultipartEncoderMonitor',
2424
'MultipartDecoder',
25+
'MultipartStreamDecoder',
2526
'ImproperBodyPartContentException',
2627
'NonMultipartContentTypeException',
2728
'__title__',

requests_toolbelt/multipart/decoder.py

Lines changed: 212 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,11 +107,12 @@ def __init__(self, content, content_type, encoding='utf-8'):
107107
self.encoding = encoding
108108
#: Parsed parts of the multipart response body
109109
self.parts = tuple()
110-
self._find_boundary()
110+
self.boundary = MultipartDecoder._find_boundary(content_type, encoding)
111111
self._parse_body(content)
112112

113-
def _find_boundary(self):
114-
ct_info = tuple(x.strip() for x in self.content_type.split(';'))
113+
@staticmethod
114+
def _find_boundary(content_type, encoding):
115+
ct_info = tuple(x.strip() for x in content_type.split(';'))
115116
mimetype = ct_info[0]
116117
if mimetype.split('/')[0].lower() != 'multipart':
117118
raise NonMultipartContentTypeException(
@@ -123,7 +124,8 @@ def _find_boundary(self):
123124
'='
124125
)
125126
if attr.lower() == 'boundary':
126-
self.boundary = encode_with(value.strip('"'), self.encoding)
127+
boundary = encode_with(value.strip('"'), encoding)
128+
return boundary
127129

128130
@staticmethod
129131
def _fix_first_part(part, boundary_marker):
@@ -154,3 +156,209 @@ def from_response(cls, response, encoding='utf-8'):
154156
content = response.content
155157
content_type = response.headers.get('content-type', None)
156158
return cls(content, content_type, encoding)
159+
160+
161+
class AlreadyIteratedException(Exception):
    """Raised when an object that was already iterated is iterated again."""
164+
165+
166+
class PreviousPartNotFinishedException(Exception):
    """Raised when the next part is requested before the previous part has
    been streamed to its end."""
170+
171+
172+
class StreamPart(object):
    """One body part whose payload is pulled lazily from an event source.

    :param headers: headers of the part, stored unchanged on ``headers``
    :param encoding: codec used by :attr:`text` to decode the payload
    :param iterator: zero-argument callable returning an iterable of
        ``(kind, data)`` tuples; ``('stream', chunk)`` carries payload and
        ``('done', False)`` marks the end of the part
    """

    def __init__(self, headers, encoding, iterator):
        self.headers = headers
        self.encoding = encoding
        self._iterator = iterator
        # Buffered payload, populated the first time ``content`` is read.
        self._content = None
        self._started = self._consumed = self._finished = False

    def __iter__(self):
        """Yield the payload chunk by chunk; may only be done once.

        :raises AlreadyIteratedException: if iteration was started before
        :raises ImproperBodyPartContentException: on any event other than
            ``'stream'`` or ``('done', False)``
        """
        if self._started:
            raise AlreadyIteratedException()
        self._started = True
        for kind, payload in self._iterator():
            if kind == 'stream':
                yield payload
                continue
            if kind == 'done' and payload is False:
                # Only an explicit end marker counts as a finished part.
                self._finished = True
                return
            raise ImproperBodyPartContentException()

    @property
    def content(self):
        """The whole payload as bytes, read once and then cached.

        :raises AlreadyIteratedException: if the part was already iterated
            directly, so the payload can no longer be collected here
        """
        if not self._consumed:
            if self._started:
                raise AlreadyIteratedException()
            self._content = b''.join(self)
            self._consumed = True
        return self._content

    @property
    def text(self):
        """The payload decoded with ``self.encoding``."""
        return self.content.decode(self.encoding)
208+
209+
210+
class MultipartStreamDecoder(object):
    """Decode a multipart body incrementally from a read callable.

    Iterating the decoder yields one ``StreamPart`` per body part. A part
    has to be streamed to its end before the next one is requested,
    otherwise ``PreviousPartNotFinishedException`` is raised. The decoder
    can be iterated only once. It can be used as a context manager, in
    which case the rest of the stream is drained on exit.
    """

    @classmethod
    def from_response(cls, response, encoding='utf-8', chunk_size=10 * 1024,
                      header_size_limit=None):
        """Create a decoder that reads ``response.raw`` chunk by chunk.

        :param response: object exposing ``headers`` and ``raw``
        :param encoding: passed through to the constructor
        :param chunk_size: size argument for each ``response.raw.read``
        :param header_size_limit: passed through to the constructor
        """
        content_type = response.headers.get('content-type', None)

        def read_chunk():
            # ``response.raw`` is resolved on every call, not up front.
            return response.raw.read(chunk_size)

        return cls(read_chunk, content_type, encoding, header_size_limit)

    def __init__(self, stream_read_func, content_type, encoding='utf-8',
                 header_size_limit=None):
        """
        :param stream_read_func: zero-argument callable returning the next
            chunk of the body; a falsy chunk is treated as end of stream
        :param content_type: Content-Type value carrying the boundary
        :param encoding: used for the boundary, the headers and the parts
        :param header_size_limit: if truthy, the largest number of
            buffered bytes tolerated while a header block is incomplete
        """
        self.content_type = content_type
        self.encoding = encoding
        self._stream_read_func = stream_read_func
        self._header_size_limit = header_size_limit
        bare_boundary = MultipartDecoder._find_boundary(content_type,
                                                        encoding)
        self._splitter = StreamSplitter()
        # Delimiter as it appears at the start of the body ...
        self._boundary = b''.join((b'--', bare_boundary))
        # ... and as it appears after the content of a part.
        self._boundary_split = b''.join((b'\r\n', self._boundary))
        # 0: before the first boundary, 1: reading headers, 2: reading body
        self._state = 0
        self._found = self._started = self._finished = False

    def close(self):
        """Drain the stream after an error or when the rest of the data
        is not wanted. Does nothing once the decoder is finished."""
        if self._finished:
            return
        try:
            while True:
                try:
                    if not self._stream_read_func():
                        break
                except StopIteration:
                    # The read callable may be backed by an iterator.
                    break
        finally:
            self._finished = True

    def __enter__(self):
        """Return the decoder itself."""
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        """Drain the stream so the underlying source can be re-used."""
        self.close()

    def __iter__(self):
        """Yield a ``StreamPart`` for every part of the body.

        :raises AlreadyIteratedException: if iteration was started before
        :raises PreviousPartNotFinishedException: if the previously
            yielded part has not been streamed to its end
        :raises ImproperBodyPartContentException: if anything other than
            a header block shows up where a new part should start
        """
        if self._started:
            raise AlreadyIteratedException()
        self._started = True
        previous = None
        for kind, payload in self._iterator():
            # Guard against advancing before the current part is done.
            if previous is not None and not previous._finished:
                raise PreviousPartNotFinishedException()
            if kind != 'headers':
                raise ImproperBodyPartContentException()
            previous = StreamPart(payload, self.encoding, self._iterator)
            yield previous

    def _iterator(self):
        """Generate parser events, sharing state through the instance.

        Events are ``('headers', mapping)``, ``('stream', chunk)`` and
        ``('done', False)``. Every call returns a new generator that
        continues from the state stored on ``self``.
        """
        while True:
            chunk = self._stream_read_func()
            # This presumes that once an empty chunk is returned nothing
            # will follow (end of stream). Buffered data is still
            # processed as long as the last search was successful.
            if not (self._found or chunk):
                self._finished = True
                return
            if self._state == 0:
                # Mimics the _fix_first_part logic of MultipartDecoder:
                # only an empty or CRLF preamble is accepted.
                preamble, self._found = self._splitter.stream(
                    chunk, self._boundary, True
                )
                if preamble and preamble != b'\r\n':
                    raise ImproperBodyPartContentException()
                if preamble or self._found:
                    self._state = 1
            elif self._state == 1:
                raw_headers, self._found = self._splitter.stream(
                    chunk, b'\r\n\r\n', True
                )
                if raw_headers or self._found:
                    if raw_headers:
                        parsed = CaseInsensitiveDict(
                            _header_parser(raw_headers.lstrip(),
                                           self.encoding)
                        )
                    else:
                        # The separator was found with nothing in front.
                        parsed = CaseInsensitiveDict({})
                    self._state = 2
                    yield 'headers', parsed
                elif (self._header_size_limit and
                        self._splitter.leftover_length >
                        self._header_size_limit):
                    # Protects against malformed input in which the
                    # header block never terminates within the limit.
                    raise ImproperBodyPartContentException()
            elif self._state == 2:
                body, self._found = self._splitter.stream(
                    chunk, self._boundary_split
                )
                if body:
                    yield 'stream', body
                if self._found:
                    # Delimiter reached: the part is complete.
                    self._state = 1
                    yield 'done', False
331+
332+
333+
class StreamSplitter(object):
    """Buffer incoming bytes and cut the buffer at a delimiter."""

    def __init__(self):
        # Bytes received so far that have not been handed out yet.
        self.leftover = b''

    def stream(self, data, split_data, return_only_full=False):
        """Append ``data`` to the buffer and search it for ``split_data``.

        :param data: bytes to append to the buffer
        :param split_data: delimiter to look for
        :param return_only_full: when true, hand out nothing until the
            delimiter has been found
        :returns: ``(chunk, found)``. If the delimiter is found, ``chunk``
            is everything in front of it and the buffer keeps what
            follows it. Otherwise ``found`` is ``False`` and ``chunk`` is
            empty, unless ``return_only_full`` is false and the buffer
            holds at least ``len(split_data)`` bytes; then all but the
            last ``len(split_data)`` bytes are handed out.
        """
        self.leftover += data
        buffered = self.leftover
        cut = buffered.find(split_data)
        if cut > -1:
            self.leftover = buffered[cut + len(split_data):]
            return buffered[:cut], True
        width = len(split_data)
        if return_only_full or len(buffered) < width:
            return b'', False
        # Keep a tail long enough to hold a delimiter that is split
        # across two chunks.
        self.leftover = buffered[-width:]
        return buffered[:-width], False

    @property
    def leftover_length(self):
        """Number of bytes currently held in the buffer."""
        return len(self.leftover)

tests/test_multipart_decoder.py

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
from requests_toolbelt.multipart.decoder import (
1010
ImproperBodyPartContentException
1111
)
12-
from requests_toolbelt.multipart.decoder import MultipartDecoder
12+
from requests_toolbelt.multipart.decoder import (
13+
MultipartDecoder, MultipartStreamDecoder
14+
)
1315
from requests_toolbelt.multipart.decoder import (
1416
NonMultipartContentTypeException
1517
)
@@ -104,26 +106,50 @@ def setUp(self):
104106
)
105107
self.boundary = 'test boundary'
106108
self.encoded_1 = MultipartEncoder(self.sample_1, self.boundary)
109+
self.encoded_1_string = self.encoded_1.to_string()
107110
self.decoded_1 = MultipartDecoder(
108-
self.encoded_1.to_string(),
111+
self.encoded_1_string,
112+
self.encoded_1.content_type
113+
)
114+
115+
def CreateMultipartStreamDecoder(self):
    """Decode the encoded sample through MultipartStreamDecoder, reading
    10 bytes at a time, and return the fully consumed parts."""
    source = io.BytesIO(self.encoded_1_string)

    def read_ten_bytes():
        return source.read(10)

    collected = []
    stream_decoder = MultipartStreamDecoder(
        read_ten_bytes,
        self.encoded_1.content_type
    )
    for body_part in stream_decoder:
        # Buffer the payload now: a part has to be finished before the
        # decoder moves on to the next one.
        body_part.content
        collected.append(body_part)
    return collected
111132

112133
def test_non_multipart_response_fails(self):
    """Both decoders reject a response whose type is not multipart."""
    jpeg_response = mock.NonCallableMagicMock(spec=requests.Response)
    jpeg_response.headers = {'content-type': 'image/jpeg'}
    for decoder_cls in (MultipartDecoder, MultipartStreamDecoder):
        with pytest.raises(NonMultipartContentTypeException):
            decoder_cls.from_response(jpeg_response)
117140

118141
def test_length_of_parts(self):
    """Each decoder produces as many parts as the sample has fields."""
    expected = len(self.sample_1)
    assert expected == len(self.decoded_1.parts)
    assert expected == len(self.CreateMultipartStreamDecoder())
120144

121145
def test_content_of_parts(self):
    """Part payloads from both decoders match the encoded sample values."""
    def matches(part, sample):
        return part.content == encode_with(sample[1], 'utf-8')

    def all_match(parts):
        return all(
            matches(part, sample)
            for part, sample in zip(parts, self.sample_1)
        )

    assert all_match(self.decoded_1.parts)
    assert all_match(self.CreateMultipartStreamDecoder())
127153

128154
def test_header_of_parts(self):
129155
def parts_header_equal(part, sample):
@@ -136,6 +162,11 @@ def parts_header_equal(part, sample):
136162
parts_header_equal(part, sample)
137163
for part, sample in parts_iter
138164
)
165+
parts_iter = zip(self.CreateMultipartStreamDecoder(), self.sample_1)
166+
assert all(
167+
parts_header_equal(part, sample)
168+
for part, sample in parts_iter
169+
)
139170

140171
def test_from_response(self):
141172
response = mock.NonCallableMagicMock(spec=requests.Response)
@@ -163,6 +194,21 @@ def test_from_response(self):
163194
assert len(decoder_2.parts[1].headers) == 0
164195
assert decoder_2.parts[1].content == b'Body 2, Line 1'
165196

197+
cnt.seek(0)
198+
response.raw = cnt
199+
decoder_2 = MultipartStreamDecoder.from_response(response)
200+
parts = []
201+
for part in decoder_2:
202+
part.content
203+
parts.append(part)
204+
assert decoder_2.content_type == response.headers['content-type']
205+
assert (
206+
parts[0].content == b'Body 1, Line 1\r\nBody 1, Line 2'
207+
)
208+
assert parts[0].headers[b'Header-1'] == b'Header-Value-1'
209+
assert len(parts[1].headers) == 0
210+
assert parts[1].content == b'Body 2, Line 1'
211+
166212
def test_from_responsecaplarge(self):
167213
response = mock.NonCallableMagicMock(spec=requests.Response)
168214
response.headers = {
@@ -189,3 +235,17 @@ def test_from_responsecaplarge(self):
189235
assert len(decoder_2.parts[1].headers) == 0
190236
assert decoder_2.parts[1].content == b'Body 2, Line 1'
191237

238+
cnt.seek(0)
239+
response.raw = cnt
240+
decoder_2 = MultipartStreamDecoder.from_response(response)
241+
parts = []
242+
for part in decoder_2:
243+
part.content
244+
parts.append(part)
245+
assert decoder_2.content_type == response.headers['content-type']
246+
assert (
247+
parts[0].content == b'Body 1, Line 1\r\nBody 1, Line 2'
248+
)
249+
assert parts[0].headers[b'Header-1'] == b'Header-Value-1'
250+
assert len(parts[1].headers) == 0
251+
assert parts[1].content == b'Body 2, Line 1'

0 commit comments

Comments
 (0)