Skip to content

Commit 3b93ce2

Browse files
committed
feat: integrate _StreamMultiplexer into AsyncMultiRangeDownloader
1 parent da7ac6e commit 3b93ce2

File tree

2 files changed

+244
-222
lines changed

2 files changed

+244
-222
lines changed

packages/google-cloud-storage/google/cloud/storage/asyncio/async_multi_range_downloader.py

Lines changed: 126 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,12 @@
3939
_DownloadState,
4040
_ReadResumptionStrategy,
4141
)
42+
from google.cloud.storage.asyncio._stream_multiplexer import (
43+
_StreamMultiplexer,
44+
_StreamError,
45+
_StreamEnd,
46+
)
47+
4248

4349
from ._utils import raise_if_no_fast_crc32c
4450

@@ -224,9 +230,7 @@ def __init__(
224230
self.read_obj_str: Optional[_AsyncReadObjectStream] = None
225231
self._is_stream_open: bool = False
226232
self._routing_token: Optional[str] = None
227-
self._read_id_to_writable_buffer_dict = {}
228-
self._read_id_to_download_ranges_id = {}
229-
self._download_ranges_id_to_pending_read_ids = {}
233+
self._multiplexer: Optional[_StreamMultiplexer] = None
230234
self.persisted_size: Optional[int] = None # updated after opening the stream
231235
self._open_retries: int = 0
232236

@@ -328,6 +332,47 @@ async def _do_open():
328332
self._is_stream_open = True
329333

330334
await retry_policy(_do_open)()
335+
self._multiplexer = _StreamMultiplexer(self.read_obj_str)
336+
337+
def _create_stream_factory(self, state, metadata):
338+
"""Create a factory that opens a new stream with current routing state."""
339+
340+
async def factory():
341+
current_handle = state.get("read_handle")
342+
current_token = state.get("routing_token")
343+
344+
stream = _AsyncReadObjectStream(
345+
client=self.client.grpc_client,
346+
bucket_name=self.bucket_name,
347+
object_name=self.object_name,
348+
generation_number=self.generation,
349+
read_handle=current_handle,
350+
)
351+
352+
current_metadata = list(metadata) if metadata else []
353+
if current_token:
354+
current_metadata.append(
355+
(
356+
"x-goog-request-params",
357+
f"routing_token={current_token}",
358+
)
359+
)
360+
361+
await stream.open(
362+
metadata=current_metadata if current_metadata else None
363+
)
364+
365+
if stream.generation_number:
366+
self.generation = stream.generation_number
367+
if stream.read_handle:
368+
self.read_handle = stream.read_handle
369+
370+
self.read_obj_str = stream
371+
self._is_stream_open = True
372+
373+
return stream
374+
375+
return factory
331376

332377
async def download_ranges(
333378
self,
@@ -353,32 +398,8 @@ async def download_ranges(
353398
* (0, 0, buffer) : downloads 0 to end , i.e. entire object.
354399
* (100, 0, buffer) : downloads from 100 to end.
355400
356-
357401
:type lock: asyncio.Lock
358-
:param lock: (Optional) An asyncio lock to synchronize sends and recvs
359-
on the underlying bidi-GRPC stream. This is required when multiple
360-
coroutines are calling this method concurrently.
361-
362-
i.e. Example usage with multiple coroutines:
363-
364-
```
365-
lock = asyncio.Lock()
366-
task1 = asyncio.create_task(mrd.download_ranges(ranges1, lock))
367-
task2 = asyncio.create_task(mrd.download_ranges(ranges2, lock))
368-
await asyncio.gather(task1, task2)
369-
370-
```
371-
372-
If user want to call this method serially from multiple coroutines,
373-
then providing a lock is not necessary.
374-
375-
```
376-
await mrd.download_ranges(ranges1)
377-
await mrd.download_ranges(ranges2)
378-
379-
# ... some other code code...
380-
381-
```
402+
:param lock: (Deprecated) This parameter is deprecated and has no effect.
382403
383404
:type retry_policy: :class:`~google.api_core.retry_async.AsyncRetry`
384405
:param retry_policy: (Optional) The retry policy to use for the operation.
@@ -397,9 +418,6 @@ async def download_ranges(
397418
if not self._is_stream_open:
398419
raise ValueError("Underlying bidi-gRPC stream is not open")
399420

400-
if lock is None:
401-
lock = asyncio.Lock()
402-
403421
if retry_policy is None:
404422
retry_policy = AsyncRetry(predicate=_is_read_retryable)
405423

@@ -419,99 +437,98 @@ async def download_ranges(
419437
"routing_token": None,
420438
}
421439

422-
# Track attempts to manage stream reuse
423-
attempt_count = 0
424-
425-
def send_ranges_and_get_bytes(
426-
requests: List[_storage_v2.ReadRange],
427-
state: Dict[str, Any],
428-
metadata: Optional[List[Tuple[str, str]]] = None,
429-
):
430-
async def generator():
431-
nonlocal attempt_count
432-
attempt_count += 1
433-
434-
if attempt_count > 1:
435-
logger.info(
436-
f"Resuming download (attempt {attempt_count}) for {len(requests)} ranges."
437-
)
438-
439-
async with lock:
440-
current_handle = state.get("read_handle")
441-
current_token = state.get("routing_token")
440+
read_ids = set(download_states.keys())
441+
queue = self._multiplexer.register(read_ids)
442442

443-
# We reopen if it's a redirect (token exists) OR if this is a retry
444-
# (not first attempt). This prevents trying to send data on a dead
445-
# stream from a previous failed attempt.
446-
should_reopen = (
447-
(attempt_count > 1)
448-
or (current_token is not None)
449-
or (metadata is not None)
450-
)
443+
try:
444+
attempt_count = 0
445+
last_broken_generation = None
451446

452-
if should_reopen:
453-
if current_token:
454-
logger.info(
455-
f"Re-opening stream with routing token: {current_token}"
456-
)
447+
def send_and_recv_via_multiplexer(
448+
requests: List[_storage_v2.ReadRange],
449+
state: Dict[str, Any],
450+
):
451+
async def generator():
452+
nonlocal attempt_count, last_broken_generation
453+
attempt_count += 1
457454

458-
self.read_obj_str = _AsyncReadObjectStream(
459-
client=self.client.grpc_client,
460-
bucket_name=self.bucket_name,
461-
object_name=self.object_name,
462-
generation_number=self.generation,
463-
read_handle=current_handle,
455+
if attempt_count > 1:
456+
logger.info(
457+
f"Resuming download (attempt {attempt_count}) for {len(requests)} ranges."
464458
)
465459

466-
# Inject routing_token into metadata if present
467-
current_metadata = list(metadata) if metadata else []
468-
if current_token:
469-
current_metadata.append(
470-
(
471-
"x-goog-request-params",
472-
f"routing_token={current_token}",
473-
)
474-
)
475-
476-
await self.read_obj_str.open(
477-
metadata=current_metadata if current_metadata else None
460+
# Reopen stream if needed
461+
should_reopen = (
462+
(attempt_count > 1 and last_broken_generation is not None)
463+
or (attempt_count == 1 and metadata is not None)
464+
)
465+
if should_reopen:
466+
broken_gen = (
467+
last_broken_generation
468+
if attempt_count > 1
469+
else self._multiplexer.stream_generation
478470
)
479-
self._is_stream_open = True
471+
stream_factory = self._create_stream_factory(state, metadata)
472+
await self._multiplexer.reopen_stream(broken_gen, stream_factory)
480473

481-
pending_read_ids = {r.read_id for r in requests}
474+
my_generation = self._multiplexer.stream_generation
482475

483476
# Send Requests
477+
pending_read_ids = {r.read_id for r in requests}
484478
for i in range(
485479
0, len(requests), _MAX_READ_RANGES_PER_BIDI_READ_REQUEST
486480
):
487-
batch = requests[i : i + _MAX_READ_RANGES_PER_BIDI_READ_REQUEST]
488-
await self.read_obj_str.send(
489-
_storage_v2.BidiReadObjectRequest(read_ranges=batch)
490-
)
481+
batch = requests[
482+
i : i + _MAX_READ_RANGES_PER_BIDI_READ_REQUEST
483+
]
484+
try:
485+
await self._multiplexer.send(
486+
_storage_v2.BidiReadObjectRequest(read_ranges=batch)
487+
)
488+
except Exception:
489+
last_broken_generation = my_generation
490+
raise
491491

492+
# Receive Responses
492493
while pending_read_ids:
493-
response = await self.read_obj_str.recv()
494-
if response is None:
494+
item = await queue.get()
495+
496+
if isinstance(item, _StreamEnd):
497+
if pending_read_ids:
498+
last_broken_generation = my_generation
499+
raise exceptions.ServiceUnavailable(
500+
"Stream ended with pending read_ids"
501+
)
495502
break
496-
if response.object_data_ranges:
497-
for data_range in response.object_data_ranges:
503+
504+
if isinstance(item, _StreamError):
505+
if item.generation < my_generation:
506+
continue # stale error, skip
507+
last_broken_generation = item.generation
508+
raise item.exception
509+
510+
# Track completion
511+
if item.object_data_ranges:
512+
for data_range in item.object_data_ranges:
498513
if data_range.range_end:
499514
pending_read_ids.discard(
500515
data_range.read_range.read_id
501516
)
502-
yield response
517+
yield item
503518

504-
return generator()
519+
return generator()
505520

506-
strategy = _ReadResumptionStrategy()
507-
retry_manager = _BidiStreamRetryManager(
508-
strategy, lambda r, s: send_ranges_and_get_bytes(r, s, metadata=metadata)
509-
)
521+
strategy = _ReadResumptionStrategy()
522+
retry_manager = _BidiStreamRetryManager(
523+
strategy, send_and_recv_via_multiplexer
524+
)
510525

511-
await retry_manager.execute(initial_state, retry_policy)
526+
await retry_manager.execute(initial_state, retry_policy)
512527

513-
if initial_state.get("read_handle"):
514-
self.read_handle = initial_state["read_handle"]
528+
if initial_state.get("read_handle"):
529+
self.read_handle = initial_state["read_handle"]
530+
finally:
531+
self._multiplexer.unregister(read_ids)
515532

516533
async def close(self):
517534
"""
@@ -520,8 +537,15 @@ async def close(self):
520537
if not self._is_stream_open:
521538
raise ValueError("Underlying bidi-gRPC stream is not open")
522539

540+
if self._multiplexer:
541+
await self._multiplexer.close()
542+
self._multiplexer = None
543+
523544
if self.read_obj_str:
524-
await self.read_obj_str.close()
545+
try:
546+
await self.read_obj_str.close()
547+
except (asyncio.CancelledError, exceptions.GoogleAPICallError):
548+
pass
525549
self.read_obj_str = None
526550
self._is_stream_open = False
527551

0 commit comments

Comments
 (0)