Skip to content

Commit 1a188fe

Browse files
committed
multipart file upload support for async file
1 parent 17c403c commit 1a188fe

5 files changed

Lines changed: 179 additions & 29 deletions

File tree

httpx/_compat.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import sys
2+
3+
if sys.version_info >= (3, 10):
4+
from contextlib import aclosing
5+
else:
6+
from contextlib import asynccontextmanager
7+
from typing import Any, AsyncIterator, Awaitable, Protocol, TypeVar
8+
9+
class _SupportsAclose(Protocol):
10+
def aclose(self) -> Awaitable[object]: ...
11+
12+
_SupportsAcloseT = TypeVar("_SupportsAcloseT", bound=_SupportsAclose)
13+
14+
@asynccontextmanager
15+
async def aclosing(thing: _SupportsAcloseT) -> AsyncIterator[Any]:
16+
try:
17+
yield thing
18+
finally:
19+
await thing.aclose()
20+
21+
22+
if sys.version_info >= (3, 13):
23+
from typing import TypeIs
24+
else:
25+
from typing_extensions import TypeIs
26+
27+
28+
__all__ = ["aclosing", "TypeIs"]

httpx/_multipart.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,15 @@
77
import typing
88
from pathlib import Path
99

10+
from ._compat import aclosing
1011
from ._types import (
1112
AsyncByteStream,
1213
FileContent,
1314
FileTypes,
1415
RequestData,
1516
RequestFiles,
1617
SyncByteStream,
18+
is_async_readable_binary_file,
1719
)
1820
from ._utils import (
1921
peek_filelike_length,
@@ -201,6 +203,11 @@ def render_headers(self) -> bytes:
201203
return self._headers
202204

203205
def render_data(self) -> typing.Iterator[bytes]:
206+
if is_async_readable_binary_file(self.file):
207+
raise TypeError(
208+
"Invalid type for file. AsyncReadableBinaryFile is not supported."
209+
)
210+
204211
if isinstance(self.file, (str, bytes)):
205212
yield to_bytes(self.file)
206213
return
@@ -216,10 +223,27 @@ def render_data(self) -> typing.Iterator[bytes]:
216223
yield to_bytes(chunk)
217224
chunk = self.file.read(self.CHUNK_SIZE)
218225

226+
async def arender_data(self) -> typing.AsyncGenerator[bytes]:
227+
if not is_async_readable_binary_file(self.file):
228+
for chunk in self.render_data():
229+
yield chunk
230+
return
231+
await self.file.seek(0)
232+
chunk = await self.file.read(self.CHUNK_SIZE)
233+
while chunk:
234+
yield to_bytes(chunk)
235+
chunk = await self.file.read(self.CHUNK_SIZE)
236+
219237
def render(self) -> typing.Iterator[bytes]:
220238
yield self.render_headers()
221239
yield from self.render_data()
222240

241+
async def arender(self) -> typing.AsyncGenerator[bytes]:
242+
yield self.render_headers()
243+
async with aclosing(self.arender_data()) as data:
244+
async for chunk in data:
245+
yield chunk
246+
223247

224248
class MultipartStream(SyncByteStream, AsyncByteStream):
225249
"""
@@ -262,6 +286,19 @@ def iter_chunks(self) -> typing.Iterator[bytes]:
262286
yield b"\r\n"
263287
yield b"--%s--\r\n" % self.boundary
264288

289+
async def aiter_chunks(self) -> typing.AsyncGenerator[bytes]:
290+
for field in self.fields:
291+
yield b"--%s\r\n" % self.boundary
292+
if isinstance(field, FileField):
293+
async with aclosing(field.arender()) as data:
294+
async for chunk in data:
295+
yield chunk
296+
else:
297+
for chunk in field.render():
298+
yield chunk
299+
yield b"\r\n"
300+
yield b"--%s--\r\n" % self.boundary
301+
265302
def get_content_length(self) -> int | None:
266303
"""
267304
Return the length of the multipart encoded content, or `None` if
@@ -296,5 +333,6 @@ def __iter__(self) -> typing.Iterator[bytes]:
296333
yield chunk
297334

298335
async def __aiter__(self) -> typing.AsyncIterator[bytes]:
299-
for chunk in self.iter_chunks():
300-
yield chunk
336+
async with aclosing(self.aiter_chunks()) as data:
337+
async for chunk in data:
338+
yield chunk

httpx/_types.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
IO,
99
TYPE_CHECKING,
1010
Any,
11+
AnyStr,
1112
AsyncIterable,
1213
AsyncIterator,
1314
Callable,
@@ -23,7 +24,7 @@
2324
Union,
2425
)
2526

26-
from typing_extensions import TypeGuard
27+
from ._compat import TypeIs
2728

2829
if TYPE_CHECKING: # pragma: no cover
2930
from ._auth import Auth # noqa: F401
@@ -75,7 +76,18 @@
7576

7677
RequestData = Mapping[str, Any]
7778

78-
FileContent = Union[IO[bytes], bytes, str]
79+
80+
class AsyncReadableBinaryFile(Protocol):
81+
async def __aiter__(self) -> AsyncIterator[AnyStr]: ...
82+
83+
async def read(self, size: int = -1) -> AnyStr: ...
84+
85+
def fileno(self) -> int: ...
86+
87+
async def seek(self, offset: int, whence: int | None = ...) -> int: ...
88+
89+
90+
FileContent = Union[IO[bytes], bytes, str, AsyncReadableBinaryFile]
7991
FileTypes = Union[
8092
# file (or bytes)
8193
FileContent,
@@ -118,22 +130,14 @@ async def aclose(self) -> None:
118130
pass
119131

120132

121-
class AsyncReadableBinaryFile(Protocol):
122-
async def __aiter__(self) -> AsyncIterator[bytes]: ...
123-
124-
async def read(self, size: int = -1) -> bytes: ...
125-
126-
def fileno(self) -> int: ...
127-
128-
129-
def is_async_readable_binary_file(fp: Any) -> TypeGuard[AsyncReadableBinaryFile]:
133+
def is_async_readable_binary_file(fp: Any) -> TypeIs[AsyncReadableBinaryFile]:
130134
return (
131135
isinstance(fp, AsyncIterable)
132136
and hasattr(fp, "read")
133137
and inspect.iscoroutinefunction(fp.read)
134138
and hasattr(fp, "fileno")
135139
and callable(fp.fileno)
136140
and not inspect.iscoroutinefunction(fp.fileno)
137-
and hasattr(fp, "mode")
138-
and "b" in fp.mode
141+
and hasattr(fp, "seek")
142+
and inspect.iscoroutinefunction(fp.seek)
139143
)

tests/test_content.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -526,16 +526,13 @@ def test_allow_nan_false():
526526
@pytest.mark.parametrize("client_method", ["put", "post"])
527527
@pytest.mark.anyio
528528
async def test_chunked_async_file_content(
529-
tmp_path, anyio_backend, monkeypatch, client_method
529+
tmp_path, anyio_backend, monkeypatch, client_method, server
530530
):
531531
total_chunks = 3
532-
533-
def echo_request_content(request: httpx.Request) -> httpx.Response:
534-
return httpx.Response(200, content=request.content)
535-
536532
content_bytes = b"".join([b"a" * AsyncIteratorByteStream.CHUNK_SIZE] * total_chunks)
537533
to_upload = tmp_path / "upload.txt"
538534
to_upload.write_bytes(content_bytes)
535+
url = server.url.copy_with(path="/echo_body")
539536

540537
async def checks(
541538
client: httpx.AsyncClient, async_file: AsyncReadableBinaryFile
@@ -557,9 +554,7 @@ def mock_fileno(*args):
557554

558555
monkeypatch.setattr(async_file, "read", mock_read)
559556
monkeypatch.setattr(async_file, "fileno", mock_fileno)
560-
response = await getattr(client, client_method)(
561-
url="http://127.0.0.1:8000/", content=async_file
562-
)
557+
response = await getattr(client, client_method)(url=url, content=async_file)
563558
assert response.status_code == 200
564559
assert response.content == content_bytes
565560
assert response.request.headers["Content-Length"] == str(len(content_bytes))
@@ -570,19 +565,15 @@ def mock_fileno(*args):
570565
await anyio.open_file(to_upload, mode="rb")
571566
if anyio_backend != "trio"
572567
else await trio.open_file(to_upload, mode="rb") as async_file,
573-
httpx.AsyncClient(
574-
transport=httpx.MockTransport(echo_request_content)
575-
) as client,
568+
httpx.AsyncClient() as client,
576569
):
577570
assert is_async_readable_binary_file(async_file)
578571
await checks(client, async_file)
579572

580573
if anyio_backend != "trio": # aiofiles doesn't work with trio
581574
async with (
582575
aiofiles.open(to_upload, mode="rb") as aio_file,
583-
httpx.AsyncClient(
584-
transport=httpx.MockTransport(echo_request_content)
585-
) as client,
576+
httpx.AsyncClient() as client,
586577
):
587578
assert is_async_readable_binary_file(aio_file)
588579
await checks(client, aio_file)

tests/test_multipart.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,13 @@
44
import tempfile
55
import typing
66

7+
import anyio
78
import pytest
9+
import trio
810

911
import httpx
12+
from httpx._multipart import FileField
13+
from httpx._types import AsyncReadableBinaryFile, is_async_readable_binary_file
1014

1115

1216
def echo_request_content(request: httpx.Request) -> httpx.Response:
@@ -467,3 +471,88 @@ def test_unicode_with_control_character(self):
467471
files = {"upload": (filename, b"<file content>")}
468472
request = httpx.Request("GET", "https://www.example.com", files=files)
469473
assert expected in request.read()
474+
475+
476+
@pytest.mark.anyio
477+
async def test_chunked_async_file_multipart(
478+
tmp_path, anyio_backend, monkeypatch, server
479+
):
480+
total_chunks = 3
481+
482+
content_bytes = b"".join([b"a" * FileField.CHUNK_SIZE] * total_chunks)
483+
to_upload = tmp_path / "upload.txt"
484+
to_upload.write_bytes(content_bytes)
485+
url = server.url.copy_with(path="/echo_body")
486+
487+
async def checks(
488+
client: httpx.AsyncClient, async_file: AsyncReadableBinaryFile
489+
) -> None:
490+
read_called = 0
491+
fileno_called = False
492+
original_read = async_file.read
493+
original_fileno = async_file.fileno
494+
495+
async def mock_read(*args, **kwargs):
496+
nonlocal read_called
497+
read_called += 1
498+
return await original_read(*args, **kwargs)
499+
500+
def mock_fileno(*args):
501+
nonlocal fileno_called
502+
fileno_called = True
503+
return original_fileno(*args)
504+
505+
monkeypatch.setattr(async_file, "read", mock_read)
506+
monkeypatch.setattr(async_file, "fileno", mock_fileno)
507+
response = await client.post(url=url, files={"file": async_file})
508+
assert response.status_code == 200
509+
boundary = response.request.headers["Content-Type"].split("boundary=")[-1]
510+
boundary_bytes = boundary.encode("ascii")
511+
pre_content = b"".join(
512+
[
513+
b"--" + boundary_bytes + b"\r\n",
514+
b'Content-Disposition: form-data; name="file"; '
515+
b'filename="upload.txt"\r\n',
516+
b"Content-Type: text/plain\r\n",
517+
b"\r\n",
518+
]
519+
)
520+
post_content = b"".join(
521+
[
522+
b"\r\n",
523+
b"--" + boundary_bytes + b"--\r\n",
524+
]
525+
)
526+
assert response.content == b"".join(
527+
[
528+
pre_content,
529+
content_bytes,
530+
post_content,
531+
]
532+
)
533+
assert response.request.headers["Content-Length"] == str(
534+
len(pre_content) + len(post_content) + len(content_bytes)
535+
)
536+
assert read_called == total_chunks + 1
537+
assert fileno_called
538+
539+
async with (
540+
await anyio.open_file(to_upload, mode="rb")
541+
if anyio_backend != "trio"
542+
else await trio.open_file(to_upload, mode="rb") as async_file,
543+
httpx.AsyncClient() as client,
544+
):
545+
assert is_async_readable_binary_file(async_file)
546+
547+
await checks(client, async_file)
548+
549+
async with (
550+
await anyio.open_file(to_upload, mode="rb")
551+
if anyio_backend != "trio"
552+
else await trio.open_file(to_upload, mode="rb") as async_file,
553+
):
554+
with (
555+
httpx.Client() as sync_client,
556+
pytest.raises(TypeError, match="AsyncReadableBinaryFile is not supported"),
557+
):
558+
sync_client.post(url, files={"file": async_file})

0 commit comments

Comments
 (0)