Skip to content

Commit cb4e640

Browse files
committed
feat: integrate _StreamMultiplexer into AsyncMultiRangeDownloader
1 parent cf7c949 commit cb4e640

File tree

2 files changed

+238
-222
lines changed

2 files changed

+238
-222
lines changed

packages/google-cloud-storage/google/cloud/storage/asyncio/async_multi_range_downloader.py

Lines changed: 121 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,11 @@
2525

2626
from google.cloud import _storage_v2
2727
from google.cloud.storage._helpers import generate_random_56_bit_integer
28+
from google.cloud.storage.asyncio._stream_multiplexer import (
29+
_StreamEnd,
30+
_StreamError,
31+
_StreamMultiplexer,
32+
)
2833
from google.cloud.storage.asyncio.async_grpc_client import (
2934
AsyncGrpcClient,
3035
)
@@ -224,9 +229,7 @@ def __init__(
224229
self.read_obj_str: Optional[_AsyncReadObjectStream] = None
225230
self._is_stream_open: bool = False
226231
self._routing_token: Optional[str] = None
227-
self._read_id_to_writable_buffer_dict = {}
228-
self._read_id_to_download_ranges_id = {}
229-
self._download_ranges_id_to_pending_read_ids = {}
232+
self._multiplexer: Optional[_StreamMultiplexer] = None
230233
self.persisted_size: Optional[int] = None # updated after opening the stream
231234
self._open_retries: int = 0
232235

@@ -328,6 +331,45 @@ async def _do_open():
328331
self._is_stream_open = True
329332

330333
await retry_policy(_do_open)()
334+
self._multiplexer = _StreamMultiplexer(self.read_obj_str)
335+
336+
def _create_stream_factory(self, state, metadata):
337+
"""Create a factory that opens a new stream with current routing state."""
338+
339+
async def factory():
340+
current_handle = state.get("read_handle")
341+
current_token = state.get("routing_token")
342+
343+
stream = _AsyncReadObjectStream(
344+
client=self.client.grpc_client,
345+
bucket_name=self.bucket_name,
346+
object_name=self.object_name,
347+
generation_number=self.generation,
348+
read_handle=current_handle,
349+
)
350+
351+
current_metadata = list(metadata) if metadata else []
352+
if current_token:
353+
current_metadata.append(
354+
(
355+
"x-goog-request-params",
356+
f"routing_token={current_token}",
357+
)
358+
)
359+
360+
await stream.open(metadata=current_metadata if current_metadata else None)
361+
362+
if stream.generation_number:
363+
self.generation = stream.generation_number
364+
if stream.read_handle:
365+
self.read_handle = stream.read_handle
366+
367+
self.read_obj_str = stream
368+
self._is_stream_open = True
369+
370+
return stream
371+
372+
return factory
331373

332374
async def download_ranges(
333375
self,
@@ -353,32 +395,8 @@ async def download_ranges(
353395
* (0, 0, buffer) : downloads 0 to end , i.e. entire object.
354396
* (100, 0, buffer) : downloads from 100 to end.
355397
356-
357398
:type lock: asyncio.Lock
358-
:param lock: (Optional) An asyncio lock to synchronize sends and recvs
359-
on the underlying bidi-GRPC stream. This is required when multiple
360-
coroutines are calling this method concurrently.
361-
362-
i.e. Example usage with multiple coroutines:
363-
364-
```
365-
lock = asyncio.Lock()
366-
task1 = asyncio.create_task(mrd.download_ranges(ranges1, lock))
367-
task2 = asyncio.create_task(mrd.download_ranges(ranges2, lock))
368-
await asyncio.gather(task1, task2)
369-
370-
```
371-
372-
If user want to call this method serially from multiple coroutines,
373-
then providing a lock is not necessary.
374-
375-
```
376-
await mrd.download_ranges(ranges1)
377-
await mrd.download_ranges(ranges2)
378-
379-
# ... some other code code...
380-
381-
```
399+
:param lock: (Deprecated) This parameter is deprecated and has no effect.
382400
383401
:type retry_policy: :class:`~google.api_core.retry_async.AsyncRetry`
384402
:param retry_policy: (Optional) The retry policy to use for the operation.
@@ -397,9 +415,6 @@ async def download_ranges(
397415
if not self._is_stream_open:
398416
raise ValueError("Underlying bidi-gRPC stream is not open")
399417

400-
if lock is None:
401-
lock = asyncio.Lock()
402-
403418
if retry_policy is None:
404419
retry_policy = AsyncRetry(predicate=_is_read_retryable)
405420

@@ -419,99 +434,97 @@ async def download_ranges(
419434
"routing_token": None,
420435
}
421436

422-
# Track attempts to manage stream reuse
423-
attempt_count = 0
424-
425-
def send_ranges_and_get_bytes(
426-
requests: List[_storage_v2.ReadRange],
427-
state: Dict[str, Any],
428-
metadata: Optional[List[Tuple[str, str]]] = None,
429-
):
430-
async def generator():
431-
nonlocal attempt_count
432-
attempt_count += 1
433-
434-
if attempt_count > 1:
435-
logger.info(
436-
f"Resuming download (attempt {attempt_count}) for {len(requests)} ranges."
437-
)
437+
read_ids = set(download_states.keys())
438+
queue = self._multiplexer.register(read_ids)
438439

439-
async with lock:
440-
current_handle = state.get("read_handle")
441-
current_token = state.get("routing_token")
440+
try:
441+
attempt_count = 0
442+
last_broken_generation = None
442443

443-
# We reopen if it's a redirect (token exists) OR if this is a retry
444-
# (not first attempt). This prevents trying to send data on a dead
445-
# stream from a previous failed attempt.
446-
should_reopen = (
447-
(attempt_count > 1)
448-
or (current_token is not None)
449-
or (metadata is not None)
450-
)
444+
def send_and_recv_via_multiplexer(
445+
requests: List[_storage_v2.ReadRange],
446+
state: Dict[str, Any],
447+
):
448+
async def generator():
449+
nonlocal attempt_count, last_broken_generation
450+
attempt_count += 1
451451

452-
if should_reopen:
453-
if current_token:
454-
logger.info(
455-
f"Re-opening stream with routing token: {current_token}"
456-
)
457-
458-
self.read_obj_str = _AsyncReadObjectStream(
459-
client=self.client.grpc_client,
460-
bucket_name=self.bucket_name,
461-
object_name=self.object_name,
462-
generation_number=self.generation,
463-
read_handle=current_handle,
452+
if attempt_count > 1:
453+
logger.info(
454+
f"Resuming download (attempt {attempt_count}) for {len(requests)} ranges."
464455
)
465456

466-
# Inject routing_token into metadata if present
467-
current_metadata = list(metadata) if metadata else []
468-
if current_token:
469-
current_metadata.append(
470-
(
471-
"x-goog-request-params",
472-
f"routing_token={current_token}",
473-
)
474-
)
475-
476-
await self.read_obj_str.open(
477-
metadata=current_metadata if current_metadata else None
457+
# Reopen stream if needed
458+
should_reopen = (
459+
attempt_count > 1 and last_broken_generation is not None
460+
) or (attempt_count == 1 and metadata is not None)
461+
if should_reopen:
462+
broken_gen = (
463+
last_broken_generation
464+
if attempt_count > 1
465+
else self._multiplexer.stream_generation
466+
)
467+
stream_factory = self._create_stream_factory(state, metadata)
468+
await self._multiplexer.reopen_stream(
469+
broken_gen, stream_factory
478470
)
479-
self._is_stream_open = True
480471

481-
pending_read_ids = {r.read_id for r in requests}
472+
my_generation = self._multiplexer.stream_generation
482473

483474
# Send Requests
475+
pending_read_ids = {r.read_id for r in requests}
484476
for i in range(
485477
0, len(requests), _MAX_READ_RANGES_PER_BIDI_READ_REQUEST
486478
):
487479
batch = requests[i : i + _MAX_READ_RANGES_PER_BIDI_READ_REQUEST]
488-
await self.read_obj_str.send(
489-
_storage_v2.BidiReadObjectRequest(read_ranges=batch)
490-
)
480+
try:
481+
await self._multiplexer.send(
482+
_storage_v2.BidiReadObjectRequest(read_ranges=batch)
483+
)
484+
except Exception:
485+
last_broken_generation = my_generation
486+
raise
491487

488+
# Receive Responses
492489
while pending_read_ids:
493-
response = await self.read_obj_str.recv()
494-
if response is None:
490+
item = await queue.get()
491+
492+
if isinstance(item, _StreamEnd):
493+
if pending_read_ids:
494+
last_broken_generation = my_generation
495+
raise exceptions.ServiceUnavailable(
496+
"Stream ended with pending read_ids"
497+
)
495498
break
496-
if response.object_data_ranges:
497-
for data_range in response.object_data_ranges:
499+
500+
if isinstance(item, _StreamError):
501+
if item.generation < my_generation:
502+
continue # stale error, skip
503+
last_broken_generation = item.generation
504+
raise item.exception
505+
506+
# Track completion
507+
if item.object_data_ranges:
508+
for data_range in item.object_data_ranges:
498509
if data_range.range_end:
499510
pending_read_ids.discard(
500511
data_range.read_range.read_id
501512
)
502-
yield response
513+
yield item
503514

504-
return generator()
515+
return generator()
505516

506-
strategy = _ReadResumptionStrategy()
507-
retry_manager = _BidiStreamRetryManager(
508-
strategy, lambda r, s: send_ranges_and_get_bytes(r, s, metadata=metadata)
509-
)
517+
strategy = _ReadResumptionStrategy()
518+
retry_manager = _BidiStreamRetryManager(
519+
strategy, send_and_recv_via_multiplexer
520+
)
510521

511-
await retry_manager.execute(initial_state, retry_policy)
522+
await retry_manager.execute(initial_state, retry_policy)
512523

513-
if initial_state.get("read_handle"):
514-
self.read_handle = initial_state["read_handle"]
524+
if initial_state.get("read_handle"):
525+
self.read_handle = initial_state["read_handle"]
526+
finally:
527+
self._multiplexer.unregister(read_ids)
515528

516529
async def close(self):
517530
"""
@@ -520,8 +533,15 @@ async def close(self):
520533
if not self._is_stream_open:
521534
raise ValueError("Underlying bidi-gRPC stream is not open")
522535

536+
if self._multiplexer:
537+
await self._multiplexer.close()
538+
self._multiplexer = None
539+
523540
if self.read_obj_str:
524-
await self.read_obj_str.close()
541+
try:
542+
await self.read_obj_str.close()
543+
except (asyncio.CancelledError, exceptions.GoogleAPICallError):
544+
pass
525545
self.read_obj_str = None
526546
self._is_stream_open = False
527547

0 commit comments

Comments
 (0)