Skip to content

Commit f8d3482

Browse files
committed
Refine async retries and upload metadata
1 parent 559b875 commit f8d3482

3 files changed

Lines changed: 127 additions & 50 deletions

File tree

tests/test_async_client.py

Lines changed: 56 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -628,7 +628,7 @@ def __init__(self, tus_url):
628628
calls.append(("client", tus_url))
629629

630630
def uploader(self, **kwargs):
631-
calls.append(("uploader", kwargs["metadata"]))
631+
calls.append(("uploader", kwargs["metadata"], kwargs["retries"]))
632632

633633
class _Uploader:
634634
def upload(self_inner):
@@ -673,7 +673,7 @@ def upload(self_inner):
673673
self.assertEqual(post_mock.await_count, 2)
674674
self.assertEqual(to_thread_mock.await_count, 1)
675675
self.assertEqual(calls[0], ("client", f"{self.server.base_url}/uploads"))
676-
self.assertEqual(calls[1][0], "uploader")
676+
self.assertEqual(calls[1], ("uploader", {"assembly_url": f"{self.server.base_url}/assemblies/assembly-123", "fieldname": "file", "filename": "payload.bin"}, 2))
677677

678678
async def test_async_assembly_resumable_rate_limit_skips_rewind_before_retrying(self):
679679
calls = []
@@ -868,7 +868,7 @@ def uploader(self, **kwargs):
868868
async with AsyncTransloadit("key", "secret", service=self.server.base_url) as client:
869869
assembly = client.new_assembly()
870870
upload = io.BytesIO(b"payload")
871-
upload.name = None
871+
upload.name = 123
872872
assembly.add_file(upload, "explicit_field")
873873

874874
with mock.patch(
@@ -953,6 +953,46 @@ async def test_async_assembly_wait_retries_after_polling_rate_limit(self):
953953
)
954954
self.assertEqual(sleep_mock.await_args_list, [mock.call(0), mock.call(0)])
955955

956+
async def test_async_assembly_wait_returns_last_poll_response_when_budget_exhausted(self):
957+
async with AsyncTransloadit("key", "secret", service=self.server.base_url) as client:
958+
assembly = client.new_assembly()
959+
960+
initial = Response(
961+
data={
962+
"ok": "ASSEMBLY_PROCESSING",
963+
"info": {"retryIn": 0},
964+
"assembly_ssl_url": f"{self.server.base_url}/assemblies/assembly-123",
965+
},
966+
status_code=200,
967+
headers={"X-Async-Route": "initial"},
968+
)
969+
rate_limited = Response(
970+
data={
971+
"ok": "ASSEMBLY_PROCESSING",
972+
"error": "RATE_LIMIT_REACHED",
973+
"info": {"retryIn": 0},
974+
"assembly_ssl_url": f"{self.server.base_url}/assemblies/assembly-123",
975+
},
976+
status_code=200,
977+
headers={"X-Async-Route": "rate_limited"},
978+
)
979+
980+
with mock.patch.object(client.request, "post", new=mock.AsyncMock(return_value=initial)) as post_mock:
981+
with mock.patch.object(
982+
client,
983+
"get_assembly",
984+
new=mock.AsyncMock(return_value=rate_limited),
985+
) as get_mock:
986+
with mock.patch("asyncio.sleep", new_callable=mock.AsyncMock) as sleep_mock:
987+
response = await assembly.create(wait=True, resumable=False, retries=1)
988+
989+
self.assertEqual(response.data["error"], "RATE_LIMIT_REACHED")
990+
post_mock.assert_awaited_once()
991+
get_mock.assert_awaited_once_with(
992+
assembly_url=f"{self.server.base_url}/assemblies/assembly-123"
993+
)
994+
self.assertEqual(sleep_mock.await_args_list, [mock.call(0)])
995+
956996
async def test_async_assembly_non_resumable_rate_limit_rewinds_files_for_retry(self):
957997
reads = []
958998

@@ -1015,7 +1055,7 @@ async def test_async_request_uses_connect_and_read_timeouts_for_uploads(self):
10151055
session = _RecordingSession({"ok": "ASSEMBLY_COMPLETED"})
10161056
client = AsyncTransloadit("key", "secret", service=self.server.base_url, session=session)
10171057
upload = io.BytesIO(b"payload")
1018-
upload.name = None
1058+
upload.name = "clip.jpg"
10191059

10201060
response = await client.request.post("/assemblies", data={"foo": "bar"}, files={"file": upload})
10211061

@@ -1024,6 +1064,18 @@ async def test_async_request_uses_connect_and_read_timeouts_for_uploads(self):
10241064
self.assertIsNone(timeout.total)
10251065
self.assertEqual(timeout.sock_connect, 60)
10261066
self.assertIsNone(timeout.sock_read)
1067+
self.assertEqual(session.calls[0][1]["data"]._fields[2][1]["Content-Type"], "image/jpeg")
1068+
1069+
async def test_async_request_uses_filename_fallback_for_trailing_slash_stream_name(self):
1070+
session = _RecordingSession({"ok": "ASSEMBLY_COMPLETED"})
1071+
client = AsyncTransloadit("key", "secret", service=self.server.base_url, session=session)
1072+
upload = io.BytesIO(b"payload")
1073+
upload.name = "/tmp/"
1074+
1075+
response = await client.request.post("/assemblies", data={"foo": "bar"}, files={"file": upload})
1076+
1077+
self.assertEqual(response.data["ok"], "ASSEMBLY_COMPLETED")
1078+
self.assertEqual(session.calls[0][1]["data"]._fields[2][0]["filename"], "file")
10271079

10281080
async def test_async_resumable_upload_uses_to_thread(self):
10291081
calls = []

transloadit/async_assembly.py

Lines changed: 58 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import asyncio
2-
import os
32

43
from tusclient import client as tus
54

65
from . import optionbuilder
6+
from .async_request import _get_upload_filename
77

88

99
class AsyncAssembly(optionbuilder.OptionBuilder):
@@ -63,11 +63,11 @@ def _rewind_files(self, positions):
6363
def _do_tus_upload(self, assembly_url, tus_url, retries):
6464
tus_client = tus.TusClient(tus_url)
6565
for key, file_stream in self.files.items():
66-
filename = getattr(file_stream, "name", None) or key
66+
filename = _get_upload_filename(file_stream, key)
6767
metadata = {
6868
"assembly_url": assembly_url,
6969
"fieldname": key,
70-
"filename": os.path.basename(filename) or key,
70+
"filename": filename,
7171
}
7272
tus_client.uploader(
7373
file_stream=file_stream,
@@ -85,55 +85,69 @@ async def create(self, wait=False, resumable=True, retries=3):
8585
"""
8686
data = self.get_options()
8787
file_positions = self._snapshot_file_positions()
88-
if resumable:
89-
extra_data = {"tus_num_expected_upload_files": len(self.files)}
90-
response = await self.transloadit.request.post(
91-
"/assemblies", extra_data=extra_data, data=data
92-
)
93-
else:
94-
response = await self.transloadit.request.post(
95-
"/assemblies", data=data, files=self.files
96-
)
97-
98-
response_data = self._response_data(response)
99-
if response_data is None:
100-
return response
88+
tus_retries = retries
89+
poll_retries = retries
90+
91+
while True:
92+
if resumable:
93+
extra_data = {"tus_num_expected_upload_files": len(self.files)} if self.files else None
94+
response = await self.transloadit.request.post(
95+
"/assemblies", extra_data=extra_data, data=data
96+
)
97+
else:
98+
response = await self.transloadit.request.post(
99+
"/assemblies", data=data, files=self.files
100+
)
101101

102-
if self._rate_limit_reached(response_data):
103-
if retries:
104-
await asyncio.sleep(response_data.get("info", {}).get("retryIn", 1))
105-
if not resumable:
106-
self._rewind_files(file_positions)
107-
return await self.create(wait, resumable, retries - 1)
108-
return response
102+
response_data = self._response_data(response)
103+
if response_data is None:
104+
return response
109105

110-
error = response_data.get("error")
111-
assembly_url = response_data.get("assembly_ssl_url")
112-
tus_url = response_data.get("tus_url")
106+
if self._rate_limit_reached(response_data):
107+
if retries:
108+
await asyncio.sleep(response_data.get("info", {}).get("retryIn", 1))
109+
if not resumable:
110+
self._rewind_files(file_positions)
111+
retries -= 1
112+
continue
113+
return response
113114

114-
if error is not None:
115-
return response
115+
error = response_data.get("error")
116+
assembly_url = response_data.get("assembly_ssl_url")
117+
tus_url = response_data.get("tus_url")
116118

117-
if resumable and self.files:
118-
if not assembly_url or not tus_url:
119+
if error is not None:
119120
return response
120-
await self._do_tus_upload_async(assembly_url, tus_url, retries)
121121

122-
if wait:
123-
if not assembly_url:
124-
return response
125-
while not self._assembly_finished(response_data):
126-
sleep_time = response_data.get("info", {}).get("retryIn", 1)
127-
await asyncio.sleep(sleep_time)
128-
response = await self.transloadit.get_assembly(
129-
assembly_url=assembly_url or response_data.get("assembly_ssl_url")
130-
)
131-
response_data = self._response_data(response)
132-
if response_data is None:
122+
if resumable and self.files:
123+
if not assembly_url or not tus_url:
133124
return response
134-
assembly_url = response_data.get("assembly_ssl_url") or assembly_url
125+
await self._do_tus_upload_async(assembly_url, tus_url, tus_retries)
135126

136-
return response
127+
if wait:
128+
if not assembly_url:
129+
return response
130+
131+
poll_response = response
132+
poll_data = response_data
133+
remaining_polls = poll_retries
134+
while not self._assembly_finished(poll_data):
135+
if remaining_polls <= 0:
136+
return poll_response
137+
sleep_time = poll_data.get("info", {}).get("retryIn", 1)
138+
await asyncio.sleep(sleep_time)
139+
poll_response = await self.transloadit.get_assembly(
140+
assembly_url=assembly_url or poll_data.get("assembly_ssl_url")
141+
)
142+
poll_data = self._response_data(poll_response)
143+
if poll_data is None:
144+
return poll_response
145+
assembly_url = poll_data.get("assembly_ssl_url") or assembly_url
146+
remaining_polls -= 1
147+
148+
return poll_response
149+
150+
return response
137151

138152
def _response_data(self, response):
139153
data = response.data

transloadit/async_request.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import mimetypes
23
import os
34
import copy
45
import hashlib
@@ -14,6 +15,15 @@
1415
TIMEOUT = 60
1516

1617

18+
def _get_upload_filename(file_stream, fallback):
19+
name = getattr(file_stream, "name", None)
20+
if isinstance(name, (str, bytes, os.PathLike)):
21+
filename = os.path.basename(name)
22+
if filename:
23+
return filename
24+
return fallback
25+
26+
1727
class AsyncRequest:
1828
"""
1929
Transloadit tailored asynchronous HTTP request object.
@@ -96,8 +106,9 @@ async def post(self, path, data=None, extra_data=None, files=None):
96106
form.add_field(key, value)
97107

98108
for key, file_stream in files.items():
99-
filename = os.path.basename(getattr(file_stream, "name", None) or key) or key
100-
form.add_field(key, file_stream, filename=filename)
109+
filename = _get_upload_filename(file_stream, key)
110+
content_type = mimetypes.guess_type(filename)[0] or "application/octet-stream"
111+
form.add_field(key, file_stream, filename=filename, content_type=content_type)
101112
payload = form
102113
else:
103114
payload = self._normalize_payload(data)

0 commit comments

Comments
 (0)