Skip to content

Commit 924fffc

Browse files
committed
Address async council review findings
1 parent 3cb1bb4 commit 924fffc

3 files changed

Lines changed: 77 additions & 4 deletions

File tree

tests/test_async_client.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,17 @@ async def text(self):
215215
return json.dumps(self.payload)
216216

217217

218+
class _UndecodableResponse:
219+
async def json(self, **kwargs):
220+
raise UnicodeDecodeError("utf-8", b"\xff", 0, 1, "invalid start byte")
221+
222+
async def text(self):
223+
raise UnicodeDecodeError("utf-8", b"\xff", 0, 1, "invalid start byte")
224+
225+
async def read(self):
226+
return b"\xff"
227+
228+
218229
class _RecordingSession:
219230
def __init__(self, payload):
220231
self.calls = []
@@ -406,6 +417,13 @@ async def test_async_client_delete_template_get_bill_and_plain_text_fallback(sel
406417
self.assertEqual(response.status_code, 200)
407418
self.assertEqual(response.headers["X-Async-Route"], "get_assembly_plain")
408419

420+
async def test_async_request_falls_back_to_bytes_when_text_decode_fails(self):
421+
client = AsyncTransloadit("key", "secret", service=self.server.base_url)
422+
423+
data = await client.request._read_response_data(_UndecodableResponse())
424+
425+
self.assertEqual(data, b"\xff")
426+
409427
async def test_async_assembly_create_raises_on_plain_text_error_response(self):
410428
plain_response = Response(
411429
data="plain assembly response",
@@ -670,6 +688,10 @@ def __init__(self, tus_url):
670688

671689
self.assertEqual(response.data["ok"], "ASSEMBLY_COMPLETED")
672690
post_mock.assert_awaited_once()
691+
self.assertEqual(
692+
post_mock.await_args.kwargs["extra_data"],
693+
{"tus_num_expected_upload_files": 0},
694+
)
673695
get_mock.assert_awaited_once_with(
674696
assembly_url=f"{self.server.base_url}/assemblies/assembly-123"
675697
)
@@ -1040,6 +1062,55 @@ async def test_async_assembly_wait_retries_after_polling_rate_limit(self):
10401062
)
10411063
self.assertEqual(sleep_mock.await_args_list, [mock.call(0), mock.call(0)])
10421064

1065+
async def test_async_assembly_wait_does_not_follow_poll_response_assembly_url(self):
1066+
initial_url = f"{self.server.base_url}/assemblies/assembly-123"
1067+
1068+
async with AsyncTransloadit("key", "secret", service=self.server.base_url) as client:
1069+
assembly = client.new_assembly()
1070+
1071+
initial = Response(
1072+
data={
1073+
"ok": "ASSEMBLY_PROCESSING",
1074+
"info": {"retryIn": 0},
1075+
"assembly_ssl_url": initial_url,
1076+
},
1077+
status_code=200,
1078+
headers={},
1079+
)
1080+
malicious_poll = Response(
1081+
data={
1082+
"ok": "ASSEMBLY_PROCESSING",
1083+
"error": "ASSEMBLY_STATUS_FETCHING_RATE_LIMIT_REACHED",
1084+
"info": {"retryIn": 0},
1085+
"assembly_ssl_url": "https://example.invalid/assemblies/evil",
1086+
},
1087+
status_code=200,
1088+
headers={},
1089+
)
1090+
completed = Response(
1091+
data={"ok": "ASSEMBLY_COMPLETED", "assembly_id": "assembly-123"},
1092+
status_code=200,
1093+
headers={},
1094+
)
1095+
1096+
with mock.patch.object(client.request, "post", new=mock.AsyncMock(return_value=initial)):
1097+
with mock.patch.object(
1098+
client,
1099+
"get_assembly",
1100+
new=mock.AsyncMock(side_effect=[malicious_poll, completed]),
1101+
) as get_mock:
1102+
with mock.patch("asyncio.sleep", new_callable=mock.AsyncMock):
1103+
response = await assembly.create(wait=True, resumable=False, retries=2)
1104+
1105+
self.assertEqual(response.data["ok"], "ASSEMBLY_COMPLETED")
1106+
self.assertEqual(
1107+
get_mock.await_args_list,
1108+
[
1109+
mock.call(assembly_url=initial_url),
1110+
mock.call(assembly_url=initial_url),
1111+
],
1112+
)
1113+
10431114
async def test_async_assembly_wait_returns_last_poll_response_when_budget_exhausted(self):
10441115
async with AsyncTransloadit("key", "secret", service=self.server.base_url) as client:
10451116
assembly = client.new_assembly()

transloadit/async_assembly.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ async def create(self, wait=False, resumable=True, retries=3):
9191

9292
while True:
9393
if resumable:
94-
extra_data = {"tus_num_expected_upload_files": len(self.files)} if self.files else None
94+
extra_data = {"tus_num_expected_upload_files": len(self.files)}
9595
response = await self.transloadit.request.post(
9696
"/assemblies", extra_data=extra_data, data=data
9797
)
@@ -148,14 +148,13 @@ async def create(self, wait=False, resumable=True, retries=3):
148148
sleep_time = poll_data.get("info", {}).get("retryIn", 1)
149149
await asyncio.sleep(sleep_time)
150150
poll_response = await self.transloadit.get_assembly(
151-
assembly_url=assembly_url or poll_data.get("assembly_ssl_url")
151+
assembly_url=assembly_url
152152
)
153153
poll_data = self._response_data(poll_response)
154154
if poll_data is None:
155155
if poll_response.status_code >= 400:
156156
raise RuntimeError(f"Unexpected non-JSON response ({poll_response.status_code}).")
157157
return poll_response
158-
assembly_url = poll_data.get("assembly_ssl_url") or assembly_url
159158

160159
return poll_response
161160

transloadit/async_request.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,10 @@ async def _read_response_data(self, response):
9696
try:
9797
return await response.json(content_type=None)
9898
except (aiohttp.ContentTypeError, json.JSONDecodeError, UnicodeDecodeError):
99-
return await response.text()
99+
try:
100+
return await response.text()
101+
except UnicodeDecodeError:
102+
return await response.read()
100103

101104
async def get(self, path, params=None):
102105
"""

0 commit comments

Comments
 (0)