Skip to content

Commit 3cb1bb4

Browse files
committed
Improve async coverage and retry safety
1 parent 83ea871 commit 3cb1bb4

4 files changed

Lines changed: 146 additions & 15 deletions

File tree

tests/test_async_client.py

Lines changed: 100 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -292,8 +292,7 @@ async def test_async_client_methods_and_context_manager(self):
292292
self.assertEqual(response.data["ok"], "TEMPLATE_CREATED")
293293
self.assertEqual(response.data["template_name"], "foo")
294294

295-
self.assertIsNotNone(client.request.session)
296-
self.assertTrue(client.request.session.closed)
295+
self.assertIsNone(client.request.session)
297296

298297
self.assertGreaterEqual(len(self.server.requests), 7)
299298
first_request = self.server.requests[0]
@@ -365,13 +364,28 @@ async def test_async_client_close_reopens_owned_session(self):
365364

366365
await client.close()
367366
self.assertTrue(first_session.closed)
367+
self.assertIsNone(client.request.session)
368368

369369
second_session = await client.request._ensure_session()
370370
self.assertIsNot(first_session, second_session)
371371
self.assertFalse(second_session.closed)
372372

373373
await client.close()
374374

375+
async def test_async_client_reopens_owned_session_when_session_is_closed(self):
376+
client = AsyncTransloadit("key", "secret", service=self.server.base_url)
377+
378+
first_session = await client.request._ensure_session()
379+
self.assertFalse(first_session.closed)
380+
381+
await first_session.close()
382+
reopened_session = await client.request._ensure_session()
383+
384+
self.assertIsNot(first_session, reopened_session)
385+
self.assertFalse(reopened_session.closed)
386+
387+
await client.close()
388+
375389
async def test_async_client_delete_template_get_bill_and_plain_text_fallback(self):
376390
async with AsyncTransloadit("key", "secret", service=self.server.base_url) as client:
377391
response = await client.delete_template("tpl-1")
@@ -408,6 +422,22 @@ async def test_async_assembly_create_raises_on_plain_text_error_response(self):
408422

409423
post_mock.assert_awaited_once()
410424

425+
async def test_async_assembly_create_returns_plain_text_success_response(self):
426+
plain_response = Response(
427+
data="plain assembly response",
428+
status_code=200,
429+
headers={"X-Async-Route": "plain"},
430+
)
431+
432+
async with AsyncTransloadit("key", "secret", service=self.server.base_url) as client:
433+
assembly = client.new_assembly()
434+
435+
with mock.patch.object(client.request, "post", new=mock.AsyncMock(return_value=plain_response)) as post_mock:
436+
response = await assembly.create(wait=False, resumable=False)
437+
438+
self.assertIs(response, plain_response)
439+
post_mock.assert_awaited_once()
440+
411441
async def test_async_assembly_wait_raises_on_plain_text_poll_response(self):
412442
initial_response = Response(
413443
data={
@@ -439,6 +469,37 @@ async def test_async_assembly_wait_raises_on_plain_text_poll_response(self):
439469
)
440470
sleep_mock.assert_awaited_once_with(0)
441471

472+
async def test_async_assembly_wait_returns_plain_text_poll_response(self):
473+
initial_response = Response(
474+
data={
475+
"ok": "ASSEMBLY_PROCESSING",
476+
"info": {"retryIn": 0},
477+
"assembly_ssl_url": f"{self.server.base_url}/assemblies/assembly-123",
478+
},
479+
status_code=200,
480+
headers={"X-Async-Route": "initial"},
481+
)
482+
plain_response = Response(
483+
data="plain assembly response",
484+
status_code=200,
485+
headers={"X-Async-Route": "plain"},
486+
)
487+
488+
async with AsyncTransloadit("key", "secret", service=self.server.base_url) as client:
489+
assembly = client.new_assembly()
490+
491+
with mock.patch.object(client.request, "post", new=mock.AsyncMock(return_value=initial_response)) as post_mock:
492+
with mock.patch.object(client, "get_assembly", new=mock.AsyncMock(return_value=plain_response)) as get_mock:
493+
with mock.patch("asyncio.sleep", new_callable=mock.AsyncMock) as sleep_mock:
494+
response = await assembly.create(wait=True, resumable=False)
495+
496+
self.assertIs(response, plain_response)
497+
post_mock.assert_awaited_once()
498+
get_mock.assert_awaited_once_with(
499+
assembly_url=f"{self.server.base_url}/assemblies/assembly-123"
500+
)
501+
sleep_mock.assert_awaited_once_with(0)
502+
442503
def test_async_signed_smart_cdn_url_matches_sync_and_rejects_bad_types(self):
443504
async_client = AsyncTransloadit("test-key", "test-secret")
444505
sync_client = Transloadit("test-key", "test-secret")
@@ -633,7 +694,7 @@ def upload(self_inner):
633694
async with AsyncTransloadit("key", "secret", service=self.server.base_url) as client:
634695
assembly = client.new_assembly()
635696
upload = io.BytesIO(b"payload")
636-
upload.name = "payload.bin"
697+
upload.name = b"payload.bin"
637698
assembly.add_file(upload)
638699

639700
rate_limited = Response(
@@ -730,6 +791,38 @@ def uploader(self, **kwargs):
730791
self.assertEqual(calls[0], ("client", f"{self.server.base_url}/uploads"))
731792
self.assertEqual(calls[1][0], "uploader")
732793

794+
async def test_async_assembly_non_resumable_rate_limit_raises_when_stream_cannot_be_snapshotted(self):
795+
class _NonSeekableStream(io.BytesIO):
796+
def tell(self):
797+
raise OSError("tell failed")
798+
799+
reads = []
800+
801+
async def fake_post(path, data=None, extra_data=None, files=None):
802+
file_stream = files["file"]
803+
reads.append(file_stream.read())
804+
return Response(
805+
data={
806+
"error": "RATE_LIMIT_REACHED",
807+
"info": {"retryIn": 0},
808+
},
809+
status_code=200,
810+
headers={},
811+
)
812+
813+
async with AsyncTransloadit("key", "secret", service=self.server.base_url) as client:
814+
assembly = client.new_assembly()
815+
assembly.add_file(_NonSeekableStream(b"payload"))
816+
817+
with mock.patch.object(client.request, "post", new=mock.AsyncMock(side_effect=fake_post)) as post_mock:
818+
with mock.patch("asyncio.sleep", new_callable=mock.AsyncMock) as sleep_mock:
819+
with self.assertRaises(RuntimeError):
820+
await assembly.create(resumable=False, retries=1)
821+
822+
self.assertEqual(reads, [b"payload"])
823+
post_mock.assert_awaited_once()
824+
sleep_mock.assert_not_awaited()
825+
733826
async def test_async_assembly_resumable_rate_limit_returns_response_without_upload_when_retries_exhausted(self):
734827
calls = []
735828

@@ -1169,9 +1262,10 @@ def test_async_assembly_helpers_cover_duplicate_names_and_rewind_edges(self):
11691262

11701263
first.read(1)
11711264
second.read(2)
1172-
positions = assembly._snapshot_file_positions()
1265+
positions, missing = assembly._snapshot_file_positions()
11731266
self.assertEqual(positions["file"], 1)
11741267
self.assertEqual(positions["file_1"], 2)
1268+
self.assertEqual(missing, [])
11751269

11761270
first.read(1)
11771271
second.read(1)
@@ -1181,8 +1275,9 @@ def test_async_assembly_helpers_cover_duplicate_names_and_rewind_edges(self):
11811275

11821276
broken = _BrokenStream()
11831277
assembly.files["broken"] = broken
1184-
positions = assembly._snapshot_file_positions()
1278+
positions, missing = assembly._snapshot_file_positions()
11851279
self.assertNotIn("broken", positions)
1280+
self.assertEqual(missing, ["broken"])
11861281

11871282
assembly._rewind_files({"missing": 4})
11881283
with self.assertRaises(RuntimeError):

tests/test_response.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import unittest
22
from unittest import mock
33

4-
from transloadit.response import Response
4+
from transloadit.response import Response, _MISSING
55

66

77
class ResponseTest(unittest.TestCase):
@@ -21,6 +21,19 @@ def test_response_data_is_assignable_and_eager_for_sync_responses(self):
2121
self.assertEqual(response.status_code, 200)
2222
self.assertEqual(response.headers, {"X-Test": "1"})
2323

24+
def test_response_lazily_rehydrates_data_when_missing(self):
25+
raw = mock.Mock()
26+
raw.json.return_value = {"ok": "lazy"}
27+
raw.status_code = 204
28+
raw.headers = {"X-Test": "1"}
29+
30+
response = Response()
31+
response._response = raw
32+
response._data = _MISSING
33+
34+
self.assertEqual(response.data, {"ok": "lazy"})
35+
raw.json.assert_called_once()
36+
2437
def test_response_supports_async_preloaded_values_and_empty_default(self):
2538
empty = Response()
2639
self.assertIsNone(empty.data)

transloadit/async_assembly.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,13 @@ def remove_file(self, field_name):
4343

4444
def _snapshot_file_positions(self):
4545
positions = {}
46+
missing = []
4647
for key, file_stream in self.files.items():
4748
try:
4849
positions[key] = file_stream.tell()
4950
except (AttributeError, OSError, ValueError):
50-
continue
51-
return positions
51+
missing.append(key)
52+
return positions, missing
5253

5354
def _rewind_files(self, positions):
5455
for key, position in positions.items():
@@ -84,7 +85,7 @@ async def create(self, wait=False, resumable=True, retries=3):
8485
Save/Submit the assembly for processing.
8586
"""
8687
data = self.get_options()
87-
file_positions = self._snapshot_file_positions()
88+
file_positions, missing_file_positions = self._snapshot_file_positions()
8889
tus_retries = retries
8990
poll_retries = retries
9091

@@ -107,6 +108,12 @@ async def create(self, wait=False, resumable=True, retries=3):
107108

108109
if self._rate_limit_reached(response_data):
109110
if retries:
111+
if not resumable and missing_file_positions:
112+
missing = ", ".join(repr(key) for key in missing_file_positions)
113+
raise RuntimeError(
114+
"Cannot retry non-resumable upload because these file streams are not seekable: "
115+
f"{missing}"
116+
)
110117
await asyncio.sleep(response_data.get("info", {}).get("retryIn", 1))
111118
if not resumable:
112119
self._rewind_files(file_positions)
@@ -170,4 +177,7 @@ def _assembly_finished(self, response_data):
170177
return is_aborted or is_canceled or is_completed or (is_failed and not (is_fetch_rate_limit or is_submit_rate_limit))
171178

172179
def _rate_limit_reached(self, response_data):
173-
return response_data.get("error") == "RATE_LIMIT_REACHED"
180+
return response_data.get("error") in {
181+
"RATE_LIMIT_REACHED",
182+
"ASSEMBLY_STATUS_FETCHING_RATE_LIMIT_REACHED",
183+
}

transloadit/async_request.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import hashlib
66
import hmac
77
import json
8+
from types import MappingProxyType
89
from datetime import datetime, timedelta, timezone
910

1011
import aiohttp
@@ -17,7 +18,10 @@
1718

1819
def _get_upload_filename(file_stream, fallback):
1920
name = getattr(file_stream, "name", None)
20-
if isinstance(name, (str, bytes, os.PathLike)):
21+
if isinstance(name, (bytes, os.PathLike)):
22+
name = os.fsdecode(name)
23+
24+
if isinstance(name, str):
2125
filename = os.path.basename(name)
2226
if filename:
2327
return filename
@@ -29,13 +33,13 @@ class AsyncRequest:
2933
Transloadit tailored asynchronous HTTP request object.
3034
"""
3135

32-
HEADERS = {"Transloadit-Client": "python-sdk:" + __version__}
36+
HEADERS = MappingProxyType({"Transloadit-Client": "python-sdk:" + __version__})
3337

3438
def __init__(self, transloadit, session=None):
3539
self.transloadit = transloadit
3640
self._session = session
3741
self._owns_session = session is None
38-
self._session_lock = asyncio.Lock()
42+
self._session_lock = None
3943

4044
@property
4145
def session(self):
@@ -44,8 +48,16 @@ def session(self):
4448
def _headers(self):
4549
return dict(self.HEADERS)
4650

51+
def _get_session_lock(self):
52+
if self._session_lock is None:
53+
# Create the lock lazily so the client can be instantiated before the loop starts.
54+
self._session_lock = asyncio.Lock()
55+
return self._session_lock
56+
4757
async def _ensure_session(self):
48-
async with self._session_lock:
58+
if self._session is not None and not self._session.closed:
59+
return self._session
60+
async with self._get_session_lock():
4961
if self._session is None:
5062
self._session = aiohttp.ClientSession()
5163
self._owns_session = True
@@ -57,9 +69,10 @@ async def _ensure_session(self):
5769
return self._session
5870

5971
async def aclose(self):
60-
async with self._session_lock:
72+
async with self._get_session_lock():
6173
if self._session is not None and not self._session.closed and self._owns_session:
6274
await self._session.close()
75+
self._session = None
6376

6477
def _timeout(self, files=False):
6578
return aiohttp.ClientTimeout(

0 commit comments

Comments
 (0)