2525
2626from google .cloud import _storage_v2
2727from google .cloud .storage ._helpers import generate_random_56_bit_integer
28+ from google .cloud .storage .asyncio ._stream_multiplexer import (
29+ _StreamEnd ,
30+ _StreamError ,
31+ _StreamMultiplexer ,
32+ )
2833from google .cloud .storage .asyncio .async_grpc_client import (
2934 AsyncGrpcClient ,
3035)
@@ -224,9 +229,7 @@ def __init__(
224229 self .read_obj_str : Optional [_AsyncReadObjectStream ] = None
225230 self ._is_stream_open : bool = False
226231 self ._routing_token : Optional [str ] = None
227- self ._read_id_to_writable_buffer_dict = {}
228- self ._read_id_to_download_ranges_id = {}
229- self ._download_ranges_id_to_pending_read_ids = {}
232+ self ._multiplexer : Optional [_StreamMultiplexer ] = None
230233 self .persisted_size : Optional [int ] = None # updated after opening the stream
231234 self ._open_retries : int = 0
232235
@@ -328,6 +331,45 @@ async def _do_open():
328331 self ._is_stream_open = True
329332
330333 await retry_policy (_do_open )()
334+ self ._multiplexer = _StreamMultiplexer (self .read_obj_str )
335+
336+ def _create_stream_factory (self , state , metadata ):
337+ """Create a factory that opens a new stream with current routing state."""
338+
339+ async def factory ():
340+ current_handle = state .get ("read_handle" )
341+ current_token = state .get ("routing_token" )
342+
343+ stream = _AsyncReadObjectStream (
344+ client = self .client .grpc_client ,
345+ bucket_name = self .bucket_name ,
346+ object_name = self .object_name ,
347+ generation_number = self .generation ,
348+ read_handle = current_handle ,
349+ )
350+
351+ current_metadata = list (metadata ) if metadata else []
352+ if current_token :
353+ current_metadata .append (
354+ (
355+ "x-goog-request-params" ,
356+ f"routing_token={ current_token } " ,
357+ )
358+ )
359+
360+ await stream .open (metadata = current_metadata if current_metadata else None )
361+
362+ if stream .generation_number :
363+ self .generation = stream .generation_number
364+ if stream .read_handle :
365+ self .read_handle = stream .read_handle
366+
367+ self .read_obj_str = stream
368+ self ._is_stream_open = True
369+
370+ return stream
371+
372+ return factory
331373
332374 async def download_ranges (
333375 self ,
@@ -353,32 +395,8 @@ async def download_ranges(
353395 * (0, 0, buffer) : downloads 0 to end , i.e. entire object.
354396 * (100, 0, buffer) : downloads from 100 to end.
355397
356-
357398 :type lock: asyncio.Lock
358- :param lock: (Optional) An asyncio lock to synchronize sends and recvs
359- on the underlying bidi-GRPC stream. This is required when multiple
360- coroutines are calling this method concurrently.
361-
362- i.e. Example usage with multiple coroutines:
363-
364- ```
365- lock = asyncio.Lock()
366- task1 = asyncio.create_task(mrd.download_ranges(ranges1, lock))
367- task2 = asyncio.create_task(mrd.download_ranges(ranges2, lock))
368- await asyncio.gather(task1, task2)
369-
370- ```
371-
372- If user want to call this method serially from multiple coroutines,
373- then providing a lock is not necessary.
374-
375- ```
376- await mrd.download_ranges(ranges1)
377- await mrd.download_ranges(ranges2)
378-
379- # ... some other code code...
380-
381- ```
399+ :param lock: (Deprecated) This parameter is deprecated and has no effect.
382400
383401 :type retry_policy: :class:`~google.api_core.retry_async.AsyncRetry`
384402 :param retry_policy: (Optional) The retry policy to use for the operation.
@@ -397,9 +415,6 @@ async def download_ranges(
397415 if not self ._is_stream_open :
398416 raise ValueError ("Underlying bidi-gRPC stream is not open" )
399417
400- if lock is None :
401- lock = asyncio .Lock ()
402-
403418 if retry_policy is None :
404419 retry_policy = AsyncRetry (predicate = _is_read_retryable )
405420
@@ -419,99 +434,97 @@ async def download_ranges(
419434 "routing_token" : None ,
420435 }
421436
422- # Track attempts to manage stream reuse
423- attempt_count = 0
424-
425- def send_ranges_and_get_bytes (
426- requests : List [_storage_v2 .ReadRange ],
427- state : Dict [str , Any ],
428- metadata : Optional [List [Tuple [str , str ]]] = None ,
429- ):
430- async def generator ():
431- nonlocal attempt_count
432- attempt_count += 1
433-
434- if attempt_count > 1 :
435- logger .info (
436- f"Resuming download (attempt { attempt_count } ) for { len (requests )} ranges."
437- )
437+ read_ids = set (download_states .keys ())
438+ queue = self ._multiplexer .register (read_ids )
438439
439- async with lock :
440- current_handle = state . get ( "read_handle" )
441- current_token = state . get ( "routing_token" )
440+ try :
441+ attempt_count = 0
442+ last_broken_generation = None
442443
443- # We reopen if it's a redirect (token exists) OR if this is a retry
444- # (not first attempt). This prevents trying to send data on a dead
445- # stream from a previous failed attempt.
446- should_reopen = (
447- (attempt_count > 1 )
448- or (current_token is not None )
449- or (metadata is not None )
450- )
444+ def send_and_recv_via_multiplexer (
445+ requests : List [_storage_v2 .ReadRange ],
446+ state : Dict [str , Any ],
447+ ):
448+ async def generator ():
449+ nonlocal attempt_count , last_broken_generation
450+ attempt_count += 1
451451
452- if should_reopen :
453- if current_token :
454- logger .info (
455- f"Re-opening stream with routing token: { current_token } "
456- )
457-
458- self .read_obj_str = _AsyncReadObjectStream (
459- client = self .client .grpc_client ,
460- bucket_name = self .bucket_name ,
461- object_name = self .object_name ,
462- generation_number = self .generation ,
463- read_handle = current_handle ,
452+ if attempt_count > 1 :
453+ logger .info (
454+ f"Resuming download (attempt { attempt_count } ) for { len (requests )} ranges."
464455 )
465456
466- # Inject routing_token into metadata if present
467- current_metadata = list (metadata ) if metadata else []
468- if current_token :
469- current_metadata .append (
470- (
471- "x-goog-request-params" ,
472- f"routing_token={ current_token } " ,
473- )
474- )
475-
476- await self .read_obj_str .open (
477- metadata = current_metadata if current_metadata else None
457+ # Reopen stream if needed
458+ should_reopen = (
459+ attempt_count > 1 and last_broken_generation is not None
460+ ) or (attempt_count == 1 and metadata is not None )
461+ if should_reopen :
462+ broken_gen = (
463+ last_broken_generation
464+ if attempt_count > 1
465+ else self ._multiplexer .stream_generation
466+ )
467+ stream_factory = self ._create_stream_factory (state , metadata )
468+ await self ._multiplexer .reopen_stream (
469+ broken_gen , stream_factory
478470 )
479- self ._is_stream_open = True
480471
481- pending_read_ids = { r . read_id for r in requests }
472+ my_generation = self . _multiplexer . stream_generation
482473
483474 # Send Requests
475+ pending_read_ids = {r .read_id for r in requests }
484476 for i in range (
485477 0 , len (requests ), _MAX_READ_RANGES_PER_BIDI_READ_REQUEST
486478 ):
487479 batch = requests [i : i + _MAX_READ_RANGES_PER_BIDI_READ_REQUEST ]
488- await self .read_obj_str .send (
489- _storage_v2 .BidiReadObjectRequest (read_ranges = batch )
490- )
480+ try :
481+ await self ._multiplexer .send (
482+ _storage_v2 .BidiReadObjectRequest (read_ranges = batch )
483+ )
484+ except Exception :
485+ last_broken_generation = my_generation
486+ raise
491487
488+ # Receive Responses
492489 while pending_read_ids :
493- response = await self .read_obj_str .recv ()
494- if response is None :
490+ item = await queue .get ()
491+
492+ if isinstance (item , _StreamEnd ):
493+ if pending_read_ids :
494+ last_broken_generation = my_generation
495+ raise exceptions .ServiceUnavailable (
496+ "Stream ended with pending read_ids"
497+ )
495498 break
496- if response .object_data_ranges :
497- for data_range in response .object_data_ranges :
499+
500+ if isinstance (item , _StreamError ):
501+ if item .generation < my_generation :
502+ continue # stale error, skip
503+ last_broken_generation = item .generation
504+ raise item .exception
505+
506+ # Track completion
507+ if item .object_data_ranges :
508+ for data_range in item .object_data_ranges :
498509 if data_range .range_end :
499510 pending_read_ids .discard (
500511 data_range .read_range .read_id
501512 )
502- yield response
513+ yield item
503514
504- return generator ()
515+ return generator ()
505516
506- strategy = _ReadResumptionStrategy ()
507- retry_manager = _BidiStreamRetryManager (
508- strategy , lambda r , s : send_ranges_and_get_bytes ( r , s , metadata = metadata )
509- )
517+ strategy = _ReadResumptionStrategy ()
518+ retry_manager = _BidiStreamRetryManager (
519+ strategy , send_and_recv_via_multiplexer
520+ )
510521
511- await retry_manager .execute (initial_state , retry_policy )
522+ await retry_manager .execute (initial_state , retry_policy )
512523
513- if initial_state .get ("read_handle" ):
514- self .read_handle = initial_state ["read_handle" ]
524+ if initial_state .get ("read_handle" ):
525+ self .read_handle = initial_state ["read_handle" ]
526+ finally :
527+ self ._multiplexer .unregister (read_ids )
515528
516529 async def close (self ):
517530 """
@@ -520,8 +533,15 @@ async def close(self):
520533 if not self ._is_stream_open :
521534 raise ValueError ("Underlying bidi-gRPC stream is not open" )
522535
536+ if self ._multiplexer :
537+ await self ._multiplexer .close ()
538+ self ._multiplexer = None
539+
523540 if self .read_obj_str :
524- await self .read_obj_str .close ()
541+ try :
542+ await self .read_obj_str .close ()
543+ except (asyncio .CancelledError , exceptions .GoogleAPICallError ):
544+ pass
525545 self .read_obj_str = None
526546 self ._is_stream_open = False
527547
0 commit comments