diff --git a/requests_toolbelt/multipart/encoder.py b/requests_toolbelt/multipart/encoder.py index 2d539617..d2b7d754 100644 --- a/requests_toolbelt/multipart/encoder.py +++ b/requests_toolbelt/multipart/encoder.py @@ -121,12 +121,36 @@ def __init__(self, fields, boundary=None, encoding='utf-8'): # Our buffer self._buffer = CustomBytesIO(encoding=encoding) + # Number of bytes read from the encoder + self._bytes_read = 0 + # Pre-compute each part's headers self._prepare_parts() # Load boundary into buffer self._write_boundary() + def __iter__(self): + # Need to implement iterator protocol otherwise requests won't set + # `is_stream` to `True` and won't rewind the body on redirects. + return self + + def __next__(self): + if self.finished: + raise StopIteration() + return self.read(8192) + + def _reset(self): + """Reset the encoder to the beginning.""" + self.finished = False + for part in self.parts: + part.reset() + self._iter_parts = iter(self.parts) + self._current_part = None + self._buffer = CustomBytesIO(encoding=self.encoding) + self._bytes_read = 0 + self._write_boundary() + @property def len(self): """Length of the multipart/form-data body. @@ -186,7 +210,7 @@ def _calculate_load_amount(self, read_size): def _load(self, amount): """Load ``amount`` number of bytes into the buffer.""" - self._buffer.smart_truncate() + smart_truncate(self._buffer) part = self._current_part or self._next_part() while amount == -1 or amount > 0: written = 0 @@ -304,15 +328,30 @@ def read(self, size=-1): remaining bytes. :returns: bytes """ - if self.finished: - return self._buffer.read(size) + if not self.finished: + bytes_to_load = size + if bytes_to_load != -1 and bytes_to_load is not None: + bytes_to_load = self._calculate_load_amount(int(size)) + + self._load(bytes_to_load) + string = self._buffer.read(size) + self._bytes_read += len(string) + return string - bytes_to_load = size - if bytes_to_load != -1 and bytes_to_load is not None: - bytes_to_load = self._calculate_load_amount(int(size)) + def tell(self): + # type: () -> int + return self._bytes_read - self._load(bytes_to_load) - return self._buffer.read(size) + def seek(self, offset, whence=0): + # type: (int, int) -> int + if (offset, whence) == (0, 0): + self._reset() + elif (offset, whence) == (0, self._bytes_read) or (offset, whence) == (0, 1): + pass + else: + raise io.UnsupportedOperation( + "MultipartEncoder only supports seeking to the beginning") + return self.tell() def IDENTITY(monitor): @@ -377,10 +416,6 @@ def __init__(self, encoder, callback=None): #: Optionally function to call after a read self.callback = callback or IDENTITY - #: Number of bytes already read from the :class:`MultipartEncoder` - #: instance - self.bytes_read = 0 - #: Avoid the same problem in bug #80 self.len = self.encoder.len @@ -394,12 +429,16 @@ def from_fields(cls, fields, boundary=None, encoding='utf-8', def content_type(self): return self.encoder.content_type + @property + def bytes_read(self): + """Number of bytes already read from the :class:`MultipartEncoder` instance.""" + return self.encoder._bytes_read + def to_string(self): return self.read() def read(self, size=-1): string = self.encoder.read(size) - self.bytes_read += len(string) self.callback(self) return string @@ -486,6 +525,10 @@ def __init__(self, headers, body): self.body = body self.headers_unread = True self.len = len(self.headers) + total_len(self.body) + try: + self.initial_pos = body.tell() + except (AttributeError, OSError, NotImplementedError): + self.initial_pos = None @classmethod def from_field(cls, field, encoding): @@ -529,6 +572,20 @@ def write_to(self, buffer, size): return written + def reset(self): + """Reset the part to the beginning.""" + if self.headers_unread: + return + if self.initial_pos is None: + raise io.UnsupportedOperation( + "Underlying body object does not support tell(). Cannot reset.") + try: + self.body.seek(self.initial_pos) + except AttributeError as e: + raise io.UnsupportedOperation( + "Underlying body object does not support seek(). Cannot reset.") from e + self.headers_unread = True + class CustomBytesIO(io.BytesIO): def __init__(self, buffer=None, encoding='utf-8'): @@ -552,16 +609,17 @@ def append(self, bytes): written = self.write(bytes) return written - def smart_truncate(self): - to_be_read = total_len(self) - already_read = self._get_end() - to_be_read - if already_read >= to_be_read: - old_bytes = self.read() - self.seek(0, 0) - self.truncate() - self.write(old_bytes) - self.seek(0, 0) # We want to be at the beginning +def smart_truncate(buf): + to_be_read = buf.len + already_read = buf.tell() + + if already_read >= to_be_read: + old_bytes = buf.read() + buf.seek(0, 0) + buf.truncate() + buf.write(old_bytes) + buf.seek(0, 0) # We want to be at the beginning class FileWrapper(object): @@ -575,6 +633,12 @@ def len(self): def read(self, length=-1): return self.fd.read(length) + def tell(self): + return self.fd.tell() + + def seek(self, offset, whence=0): + return self.fd.seek(offset, whence) + class FileFromURLWrapper(object): """File from URL wrapper. diff --git a/tests/test_multipart_encoder.py b/tests/test_multipart_encoder.py index f864487c..edece9c4 100644 --- a/tests/test_multipart_encoder.py +++ b/tests/test_multipart_encoder.py @@ -1,12 +1,15 @@ # -*- coding: utf-8 -*- import unittest import io +import os +import tempfile import requests import pytest from requests_toolbelt.multipart.encoder import ( - CustomBytesIO, MultipartEncoder, FileFromURLWrapper, FileNotSupportedError) + CustomBytesIO, MultipartEncoder, FileFromURLWrapper, FileNotSupportedError, + smart_truncate) from requests_toolbelt._compat import filepost from . import get_betamax @@ -78,7 +81,7 @@ def test_truncates_intelligently(self): self.instance.write(b'abcdefghijklmnopqrstuvwxyzabcd') # 30 bytes assert self.instance.tell() == 30 self.instance.seek(-10, 2) - self.instance.smart_truncate() + smart_truncate(self.instance) assert self.instance.len == 10 assert self.instance.read() == b'uvwxyzabcd' assert self.instance.tell() == 10 @@ -128,6 +131,17 @@ def test_no_content_length_header(self): ) +EXPECTED = ( + b'--this-is-a-boundary\r\n' + b'Content-Disposition: form-data; name="field"\r\n\r\n' + b'value\r\n' + b'--this-is-a-boundary\r\n' + b'Content-Disposition: form-data; name="other_field"\r\n\r\n' + b'other_value\r\n' + b'--this-is-a-boundary--\r\n' +) + + class TestMultipartEncoder(unittest.TestCase): def setUp(self): self.parts = [('field', 'value'), ('other_field', 'other_value')] @@ -135,15 +149,7 @@ def setUp(self): self.instance = MultipartEncoder(self.parts, boundary=self.boundary) def test_to_string(self): - assert self.instance.to_string() == ( - '--this-is-a-boundary\r\n' - 'Content-Disposition: form-data; name="field"\r\n\r\n' - 'value\r\n' - '--this-is-a-boundary\r\n' - 'Content-Disposition: form-data; name="other_field"\r\n\r\n' - 'other_value\r\n' - '--this-is-a-boundary--\r\n' - ).encode() + assert self.instance.to_string() == EXPECTED def test_content_type(self): expected = 'multipart/form-data; boundary=this-is-a-boundary' @@ -202,6 +208,9 @@ def test_reads_file_from_url_wrapper(self): [('field', 'foo'), ('file', FileFromURLWrapper(url, session=s))]) assert m.read() is not None + with pytest.raises(OSError): + m.seek(0, 0) + def test_reads_open_file_objects_with_a_specified_filename(self): with open('setup.py', 'rb') as fd: m = MultipartEncoder( @@ -319,5 +328,52 @@ def test_no_parts(self): output = m.read().decode('utf-8') assert output == '----90967316f8404798963cce746a4f4ef9--\r\n' + def test_seeking(self): + field_data = self.parts[0][1].encode('utf-8') + + tmpfile = tempfile.TemporaryFile() + tmpfile.write(field_data) + tmpfile.seek(0) + parts = self.parts.copy() + parts[0] = (self.parts[0][0], tmpfile) + m = MultipartEncoder(parts, boundary=self.boundary) + + tmpfile = tempfile.TemporaryFile() + gunk = b"Some gunk at the beginning" + tmpfile.write(gunk) + tmpfile.write(field_data) + tmpfile.seek(len(gunk)) + parts = self.parts.copy() + parts[0] = (self.parts[0][0], tmpfile) + m2 = MultipartEncoder(parts, boundary=self.boundary) + + for instance in (self.instance, m, m2): + assert instance.tell() == 0 + assert instance.read() == EXPECTED + # Exhausted: + assert instance.read() == b'' + assert instance.seek(0) == 0 + assert instance.read() == EXPECTED + + def test_redirect(self): + """Verifies integration with requests.""" + tmpfile = tempfile.TemporaryFile() + tmpfile.write(b'from-file') + tmpfile.seek(0) + + m = MultipartEncoder([('field', 'foo'), ('myfile', tmpfile)]) + # Can't use betamax here - it responds too quickly and requests doesn't + # have time to start reading from the MultipartEncoder before the + # redirect response is returned - so the seek never happens. + resp = requests.post( + 'https://httpbin.org/redirect-to?status_code=307&url=/post', + data=m, headers={'Content-Type': m.content_type}, + timeout=10) + resp.raise_for_status() + print(resp.json()) + assert resp.json()['form']['myfile'] == 'from-file' + assert resp.json()['form']['field'] == 'foo' + + if __name__ == '__main__': unittest.main()