Skip to content

Commit 1722c93

Browse files
committed
Harden async upload edge cases
1 parent bb58520 commit 1722c93

5 files changed

Lines changed: 209 additions & 27 deletions

File tree

tests/test_async_client.py

Lines changed: 131 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
1-
import io
21
import asyncio
2+
import io
33
import json
4+
import threading
45
from pathlib import Path
56
from unittest import IsolatedAsyncioTestCase, mock
67

7-
from aiohttp import web
8+
from aiohttp import payload, web
89

910
from transloadit.async_client import AsyncTransloadit
11+
from transloadit.async_request import _NonClosingUploadStream
1012
from transloadit.client import Transloadit
1113
from transloadit.response import Response
1214

@@ -232,10 +234,22 @@ def __init__(self, payload):
232234
self.closed = False
233235
self.payload = payload
234236

237+
def delete(self, url, **kwargs):
238+
self.calls.append((url, kwargs))
239+
return _FakeResponseContext(self.payload)
240+
241+
def get(self, url, **kwargs):
242+
self.calls.append((url, kwargs))
243+
return _FakeResponseContext(self.payload)
244+
235245
def post(self, url, **kwargs):
236246
self.calls.append((url, kwargs))
237247
return _FakeResponseContext(self.payload)
238248

249+
def put(self, url, **kwargs):
250+
self.calls.append((url, kwargs))
251+
return _FakeResponseContext(self.payload)
252+
239253
async def close(self):
240254
self.closed = True
241255

@@ -274,8 +288,8 @@ async def test_async_client_methods_and_context_manager(self):
274288
self.assertEqual(response.data["ok"], "ASSEMBLY_COMPLETED")
275289
self.assertEqual(response.data["assembly_id"], "abc123")
276290
self.assertEqual(response.status_code, 200)
277-
self.assertIs(type(response.headers), dict)
278291
self.assertEqual(response.headers["X-Async-Route"], "get_assembly")
292+
self.assertEqual(response.headers["x-async-route"], "get_assembly")
279293

280294
response = await client.list_assemblies()
281295
self.assertEqual(response.data["items"], [])
@@ -368,6 +382,28 @@ async def test_async_client_normalizes_service_and_rejects_missing_ids(self):
368382
with self.assertRaises(RuntimeError):
369383
await closed_client.get_assembly(assembly_id="abc123")
370384

385+
async def test_async_client_quotes_path_ids(self):
386+
session = _RecordingSession({"ok": "ASSEMBLY_COMPLETED"})
387+
client = AsyncTransloadit("key", "secret", service=self.server.base_url, session=session)
388+
389+
await client.get_assembly(assembly_id="assembly/with?chars")
390+
await client.cancel_assembly(assembly_id="cancel/with?chars")
391+
await client.get_template("template/with?chars")
392+
await client.update_template("update/with?chars", {"name": "foo"})
393+
await client.delete_template("delete/with?chars")
394+
395+
urls = [call[0] for call in session.calls]
396+
self.assertEqual(
397+
urls,
398+
[
399+
f"{self.server.base_url}/assemblies/assembly%2Fwith%3Fchars",
400+
f"{self.server.base_url}/assemblies/cancel%2Fwith%3Fchars",
401+
f"{self.server.base_url}/templates/template%2Fwith%3Fchars",
402+
f"{self.server.base_url}/templates/update%2Fwith%3Fchars",
403+
f"{self.server.base_url}/templates/delete%2Fwith%3Fchars",
404+
],
405+
)
406+
371407
async def test_async_client_close_reopens_owned_session(self):
372408
client = AsyncTransloadit("key", "secret", service=self.server.base_url)
373409

@@ -469,6 +505,34 @@ async def test_async_assembly_create_returns_plain_text_success_response(self):
469505
self.assertIs(response, plain_response)
470506
post_mock.assert_awaited_once()
471507

508+
async def test_async_assembly_resumable_plain_text_success_response_raises_before_tus_upload(self):
509+
calls = []
510+
511+
class _TusClient:
512+
def __init__(self, tus_url):
513+
calls.append(("client", tus_url))
514+
515+
def uploader(self, **kwargs):
516+
raise AssertionError("TUS upload should not start without upload URLs")
517+
518+
plain_response = Response(
519+
data="plain assembly response",
520+
status_code=200,
521+
headers={"X-Async-Route": "plain"},
522+
)
523+
524+
async with AsyncTransloadit("key", "secret", service=self.server.base_url) as client:
525+
assembly = client.new_assembly()
526+
assembly.add_file(io.BytesIO(b"payload"))
527+
528+
with mock.patch.object(client.request, "post", new=mock.AsyncMock(return_value=plain_response)) as post_mock:
529+
with mock.patch("transloadit.async_assembly.tus.TusClient", new=_TusClient):
530+
with self.assertRaises(RuntimeError):
531+
await assembly.create(resumable=True)
532+
533+
post_mock.assert_awaited_once()
534+
self.assertEqual(calls, [])
535+
472536
async def test_async_assembly_wait_raises_on_plain_text_poll_response(self):
473537
initial_response = Response(
474538
data={
@@ -1182,6 +1246,7 @@ async def test_async_assembly_wait_returns_last_poll_response_when_budget_exhaus
11821246

11831247
async def test_async_assembly_non_resumable_rate_limit_rewinds_files_for_retry(self):
11841248
reads = []
1249+
upload = io.BytesIO(b"payload")
11851250

11861251
async def fake_post(path, data=None, extra_data=None, files=None):
11871252
file_stream = files["file"]
@@ -1201,12 +1266,15 @@ async def fake_post(path, data=None, extra_data=None, files=None):
12011266
headers={},
12021267
)
12031268

1269+
async def fake_sleep(delay):
1270+
self.assertEqual(upload.tell(), 0)
1271+
12041272
async with AsyncTransloadit("key", "secret", service=self.server.base_url) as client:
12051273
assembly = client.new_assembly()
1206-
assembly.add_file(io.BytesIO(b"payload"))
1274+
assembly.add_file(upload)
12071275

12081276
with mock.patch.object(client.request, "post", new=mock.AsyncMock(side_effect=fake_post)):
1209-
with mock.patch("asyncio.sleep", new_callable=mock.AsyncMock) as sleep_mock:
1277+
with mock.patch("asyncio.sleep", new=mock.AsyncMock(side_effect=fake_sleep)) as sleep_mock:
12101278
response = await assembly.create(resumable=False, retries=2)
12111279

12121280
self.assertEqual(response.data["ok"], "ASSEMBLY_COMPLETED")
@@ -1238,6 +1306,47 @@ def seek(self, position, *args, **kwargs):
12381306

12391307
post_mock.assert_awaited_once()
12401308

1309+
async def test_async_assembly_rate_limit_ignores_malformed_error_values(self):
1310+
client = AsyncTransloadit("key", "secret", service=self.server.base_url)
1311+
assembly = client.new_assembly()
1312+
1313+
self.assertFalse(assembly._rate_limit_reached({"error": ["RATE_LIMIT_REACHED"]}))
1314+
self.assertFalse(assembly._rate_limit_reached({"error": {"code": "RATE_LIMIT_REACHED"}}))
1315+
1316+
async def test_async_tus_upload_cancellation_waits_for_thread_to_finish(self):
1317+
client = AsyncTransloadit("key", "secret", service=self.server.base_url)
1318+
assembly = client.new_assembly()
1319+
started = threading.Event()
1320+
release = threading.Event()
1321+
finished = threading.Event()
1322+
1323+
def blocking_upload(assembly_url, tus_url, retries):
1324+
started.set()
1325+
release.wait(timeout=5)
1326+
finished.set()
1327+
1328+
assembly._do_tus_upload = blocking_upload
1329+
upload_task = asyncio.create_task(
1330+
assembly._do_tus_upload_async(
1331+
f"{self.server.base_url}/assemblies/assembly-123",
1332+
f"{self.server.base_url}/uploads",
1333+
retries=1,
1334+
)
1335+
)
1336+
1337+
await asyncio.to_thread(started.wait, 5)
1338+
upload_task.cancel()
1339+
await asyncio.sleep(0.05)
1340+
1341+
self.assertFalse(upload_task.done())
1342+
self.assertFalse(finished.is_set())
1343+
1344+
release.set()
1345+
with self.assertRaises(asyncio.CancelledError):
1346+
await upload_task
1347+
1348+
self.assertTrue(finished.is_set())
1349+
12411350
async def test_async_request_uses_connect_and_read_timeouts_for_uploads(self):
12421351
session = _RecordingSession({"ok": "ASSEMBLY_COMPLETED"})
12431352
client = AsyncTransloadit("key", "secret", service=self.server.base_url, session=session)
@@ -1250,9 +1359,25 @@ async def test_async_request_uses_connect_and_read_timeouts_for_uploads(self):
12501359
timeout = session.calls[0][1]["timeout"]
12511360
self.assertIsNone(timeout.total)
12521361
self.assertEqual(timeout.sock_connect, 60)
1253-
self.assertIsNone(timeout.sock_read)
1362+
self.assertEqual(timeout.sock_read, 60)
12541363
self.assertEqual(session.calls[0][1]["data"]._fields[2][1]["Content-Type"], "image/jpeg")
12551364

1365+
async def test_async_request_upload_does_not_close_caller_stream(self):
1366+
fixture_path = Path(__file__).resolve().parents[1] / "LICENSE"
1367+
upload = fixture_path.open("rb")
1368+
1369+
try:
1370+
upload_payload = payload.get_payload(_NonClosingUploadStream(upload))
1371+
await upload_payload.close()
1372+
await asyncio.sleep(0.05)
1373+
1374+
self.assertFalse(upload.closed)
1375+
upload.seek(0)
1376+
self.assertEqual(upload.read(5), fixture_path.read_bytes()[:5])
1377+
finally:
1378+
if not upload.closed:
1379+
upload.close()
1380+
12561381
async def test_async_request_payload_preserves_custom_auth_constraints(self):
12571382
client = AsyncTransloadit("key", "secret", service=self.server.base_url)
12581383

transloadit/async_assembly.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,16 @@ def _do_tus_upload(self, assembly_url, tus_url, retries):
7878
).upload()
7979

8080
async def _do_tus_upload_async(self, assembly_url, tus_url, retries):
81-
await asyncio.to_thread(self._do_tus_upload, assembly_url, tus_url, retries)
81+
upload_task = asyncio.create_task(
82+
asyncio.to_thread(self._do_tus_upload, assembly_url, tus_url, retries)
83+
)
84+
try:
85+
await asyncio.shield(upload_task)
86+
except asyncio.CancelledError:
87+
try:
88+
await asyncio.shield(upload_task)
89+
finally:
90+
raise
8291

8392
async def create(self, wait=False, resumable=True, retries=3):
8493
"""
@@ -104,6 +113,8 @@ async def create(self, wait=False, resumable=True, retries=3):
104113
if response_data is None:
105114
if response.status_code >= 400:
106115
raise RuntimeError(f"Unexpected non-JSON response ({response.status_code}).")
116+
if resumable and self.files:
117+
raise RuntimeError("Resumable assembly response is missing upload URLs.")
107118
return response
108119

109120
if self._rate_limit_reached(response_data):
@@ -114,9 +125,9 @@ async def create(self, wait=False, resumable=True, retries=3):
114125
"Cannot retry non-resumable upload because these file streams are not seekable: "
115126
f"{missing}"
116127
)
117-
await asyncio.sleep(response_data.get("info", {}).get("retryIn", 1))
118128
if not resumable:
119129
self._rewind_files(file_positions)
130+
await asyncio.sleep(response_data.get("info", {}).get("retryIn", 1))
120131
retries -= 1
121132
continue
122133
return response
@@ -176,7 +187,8 @@ def _assembly_finished(self, response_data):
176187
return is_aborted or is_canceled or is_completed or (is_failed and not (is_fetch_rate_limit or is_submit_rate_limit))
177188

178189
def _rate_limit_reached(self, response_data):
179-
return response_data.get("error") in {
190+
error = response_data.get("error")
191+
return isinstance(error, str) and error in {
180192
"RATE_LIMIT_REACHED",
181193
"ASSEMBLY_STATUS_FETCHING_RATE_LIMIT_REACHED",
182194
}

transloadit/async_client.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import hmac
33
import time
44
from typing import List, Optional, Union
5-
from urllib.parse import quote_plus, urlencode
5+
from urllib.parse import quote, quote_plus, urlencode
66

77
from . import async_assembly, async_request, async_template
88

@@ -13,6 +13,10 @@ def _stringify_url_param(value: Union[str, int, float, bool]) -> str:
1313
return str(value)
1414

1515

16+
def _quote_path_segment(value: str) -> str:
17+
return quote(str(value), safe="")
18+
19+
1620
class AsyncTransloadit:
1721
"""
1822
Asynchronous client interface to the Transloadit API.
@@ -61,7 +65,7 @@ async def get_assembly(self, assembly_id: str = None, assembly_url: str = None):
6165
if not (assembly_id or assembly_url):
6266
raise ValueError("Either 'assembly_id' or 'assembly_url' cannot be None.")
6367

64-
url = assembly_url if assembly_url else f"/assemblies/{assembly_id}"
68+
url = assembly_url if assembly_url else f"/assemblies/{_quote_path_segment(assembly_id)}"
6569
return await self.request.get(url)
6670

6771
async def list_assemblies(self, params: dict = None):
@@ -77,14 +81,14 @@ async def cancel_assembly(self, assembly_id: str = None, assembly_url: str = Non
7781
if not (assembly_id or assembly_url):
7882
raise ValueError("Either 'assembly_id' or 'assembly_url' cannot be None.")
7983

80-
url = assembly_url if assembly_url else f"/assemblies/{assembly_id}"
84+
url = assembly_url if assembly_url else f"/assemblies/{_quote_path_segment(assembly_id)}"
8185
return await self.request.delete(url)
8286

8387
async def get_template(self, template_id: str):
8488
"""
8589
Get the template specified by the 'template_id'.
8690
"""
87-
return await self.request.get(f"/templates/{template_id}")
91+
return await self.request.get(f"/templates/{_quote_path_segment(template_id)}")
8892

8993
async def list_templates(self, params: Optional[dict] = None):
9094
"""
@@ -102,13 +106,13 @@ async def update_template(self, template_id: str, data: dict):
102106
"""
103107
Update the template specified by the 'template_id'.
104108
"""
105-
return await self.request.put(f"/templates/{template_id}", data=data)
109+
return await self.request.put(f"/templates/{_quote_path_segment(template_id)}", data=data)
106110

107111
async def delete_template(self, template_id: str):
108112
"""
109113
Delete the template specified by the 'template_id'.
110114
"""
111-
return await self.request.delete(f"/templates/{template_id}")
115+
return await self.request.delete(f"/templates/{_quote_path_segment(template_id)}")
112116

113117
async def get_bill(self, month: int, year: int):
114118
"""

0 commit comments

Comments
 (0)