Skip to content

Commit def57f6

Browse files
committed
Harden async retry rewinds
1 parent fb1bcfd commit def57f6

2 files changed

Lines changed: 93 additions & 4 deletions

File tree

tests/test_async_client.py

Lines changed: 89 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -635,6 +635,67 @@ def upload(self_inner):
635635
self.assertEqual(calls[0], ("client", f"{self.server.base_url}/uploads"))
636636
self.assertEqual(calls[1][0], "uploader")
637637

638+
async def test_async_assembly_resumable_rate_limit_skips_rewind_before_retrying(self):
639+
calls = []
640+
641+
class _BrokenRewindStream(io.BytesIO):
642+
def seek(self, position, *args, **kwargs):
643+
raise OSError("seek failed")
644+
645+
class _Uploader:
646+
def __init__(self, metadata):
647+
self.metadata = metadata
648+
649+
def upload(self):
650+
calls.append(("upload", dict(self.metadata)))
651+
652+
class _TusClient:
653+
def __init__(self, tus_url):
654+
calls.append(("client", tus_url))
655+
656+
def uploader(self, **kwargs):
657+
calls.append(("uploader", dict(kwargs["metadata"])))
658+
return _Uploader(kwargs["metadata"])
659+
660+
async with AsyncTransloadit("key", "secret", service=self.server.base_url) as client:
661+
assembly = client.new_assembly()
662+
upload = _BrokenRewindStream(b"payload")
663+
upload.name = "payload.bin"
664+
assembly.add_file(upload)
665+
666+
rate_limited = Response(
667+
data={
668+
"error": "RATE_LIMIT_REACHED",
669+
"info": {"retryIn": 0},
670+
},
671+
status_code=200,
672+
headers={},
673+
)
674+
success = Response(
675+
data={
676+
"assembly_ssl_url": f"{self.server.base_url}/assemblies/assembly-123",
677+
"tus_url": f"{self.server.base_url}/uploads",
678+
},
679+
status_code=200,
680+
headers={},
681+
)
682+
683+
with mock.patch.object(
684+
client.request,
685+
"post",
686+
new=mock.AsyncMock(side_effect=[rate_limited, success]),
687+
) as post_mock:
688+
with mock.patch("asyncio.sleep", new_callable=mock.AsyncMock):
689+
with mock.patch("asyncio.to_thread", new=mock.AsyncMock(side_effect=lambda func, *args: func(*args))) as to_thread_mock:
690+
with mock.patch("transloadit.async_assembly.tus.TusClient", new=_TusClient):
691+
response = await assembly.create(resumable=True, retries=2)
692+
693+
self.assertEqual(response.data["assembly_ssl_url"], f"{self.server.base_url}/assemblies/assembly-123")
694+
self.assertEqual(post_mock.await_count, 2)
695+
self.assertEqual(to_thread_mock.await_count, 1)
696+
self.assertEqual(calls[0], ("client", f"{self.server.base_url}/uploads"))
697+
self.assertEqual(calls[1][0], "uploader")
698+
638699
async def test_async_assembly_resumable_rate_limit_returns_response_without_upload_when_retries_exhausted(self):
639700
calls = []
640701

@@ -885,6 +946,31 @@ async def fake_post(path, data=None, extra_data=None, files=None):
885946
self.assertEqual(reads, [b"payload", b"payload"])
886947
sleep_mock.assert_awaited_once_with(0)
887948

949+
async def test_async_assembly_non_resumable_rate_limit_raises_when_rewind_fails(self):
950+
class _BrokenRewindStream(io.BytesIO):
951+
def seek(self, position, *args, **kwargs):
952+
raise OSError("seek failed")
953+
954+
async with AsyncTransloadit("key", "secret", service=self.server.base_url) as client:
955+
assembly = client.new_assembly()
956+
assembly.add_file(_BrokenRewindStream(b"payload"))
957+
958+
rate_limited = Response(
959+
data={
960+
"error": "RATE_LIMIT_REACHED",
961+
"info": {"retryIn": 0},
962+
},
963+
status_code=200,
964+
headers={},
965+
)
966+
967+
with mock.patch.object(client.request, "post", new=mock.AsyncMock(return_value=rate_limited)) as post_mock:
968+
with mock.patch("asyncio.sleep", new_callable=mock.AsyncMock):
969+
with self.assertRaises(RuntimeError):
970+
await assembly.create(resumable=False, retries=1)
971+
972+
post_mock.assert_awaited_once()
973+
888974
async def test_async_request_uses_connect_and_read_timeouts_for_uploads(self):
889975
session = _RecordingSession({"ok": "ASSEMBLY_COMPLETED"})
890976
client = AsyncTransloadit("key", "secret", service=self.server.base_url, session=session)
@@ -985,4 +1071,6 @@ def test_async_assembly_helpers_cover_duplicate_names_and_rewind_edges(self):
9851071
positions = assembly._snapshot_file_positions()
9861072
self.assertNotIn("broken", positions)
9871073

988-
assembly._rewind_files({"missing": 4, "broken": 7})
1074+
assembly._rewind_files({"missing": 4})
1075+
with self.assertRaises(RuntimeError):
1076+
assembly._rewind_files({"broken": 7})

transloadit/async_assembly.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ def _rewind_files(self, positions):
5757
continue
5858
try:
5959
file_stream.seek(position)
60-
except (AttributeError, OSError, ValueError):
61-
continue
60+
except (AttributeError, OSError, ValueError) as exc:
61+
raise RuntimeError(f"Unable to rewind file stream {key!r}.") from exc
6262

6363
def _do_tus_upload(self, assembly_url, tus_url, retries):
6464
tus_client = tus.TusClient(tus_url)
@@ -102,7 +102,8 @@ async def create(self, wait=False, resumable=True, retries=3):
102102
if self._rate_limit_reached(response_data):
103103
if retries:
104104
await asyncio.sleep(response_data.get("info", {}).get("retryIn", 1))
105-
self._rewind_files(file_positions)
105+
if not resumable:
106+
self._rewind_files(file_positions)
106107
return await self.create(wait, resumable, retries - 1)
107108
return response
108109

0 commit comments

Comments
 (0)