3939 _DownloadState ,
4040 _ReadResumptionStrategy ,
4141)
42+ from google .cloud .storage .asyncio ._stream_multiplexer import (
43+ _StreamMultiplexer ,
44+ _StreamError ,
45+ _StreamEnd ,
46+ )
47+
4248
4349from ._utils import raise_if_no_fast_crc32c
4450
@@ -224,9 +230,7 @@ def __init__(
224230 self .read_obj_str : Optional [_AsyncReadObjectStream ] = None
225231 self ._is_stream_open : bool = False
226232 self ._routing_token : Optional [str ] = None
227- self ._read_id_to_writable_buffer_dict = {}
228- self ._read_id_to_download_ranges_id = {}
229- self ._download_ranges_id_to_pending_read_ids = {}
233+ self ._multiplexer : Optional [_StreamMultiplexer ] = None
230234 self .persisted_size : Optional [int ] = None # updated after opening the stream
231235 self ._open_retries : int = 0
232236
@@ -328,6 +332,47 @@ async def _do_open():
328332 self ._is_stream_open = True
329333
330334 await retry_policy (_do_open )()
335+ self ._multiplexer = _StreamMultiplexer (self .read_obj_str )
336+
def _create_stream_factory(self, state, metadata):
    """Build a coroutine function that opens a new read stream on demand.

    ``state`` is consulted when the returned factory is *called*, not when
    this method runs, so a read handle or routing token written into
    ``state`` after this method returns is still used for the next open.

    :type state: dict
    :param state: Mapping read via ``.get`` for the ``"read_handle"`` and
        ``"routing_token"`` keys.

    :type metadata: list of (str, str) tuples
    :param metadata: (Optional) Extra metadata passed to the stream's
        ``open`` call. It is copied before use, so the caller's list is
        never modified.

    :rtype: coroutine function
    :returns: A zero-argument ``async`` callable that opens and returns a
        new ``_AsyncReadObjectStream``. As a side effect it updates
        ``self.generation`` and ``self.read_handle`` (when the opened
        stream reports them), ``self.read_obj_str`` and
        ``self._is_stream_open``.
    """

    async def factory():
        # Looked up at call time so the latest values in ``state`` are used.
        read_handle = state.get("read_handle")
        routing_token = state.get("routing_token")

        stream = _AsyncReadObjectStream(
            client=self.client.grpc_client,
            bucket_name=self.bucket_name,
            object_name=self.object_name,
            generation_number=self.generation,
            read_handle=read_handle,
        )

        # Work on a copy: the routing header is appended per open and must
        # not accumulate on the caller-supplied metadata across calls.
        open_metadata = list(metadata) if metadata else []
        if routing_token:
            # No whitespace around the value: the header must contain
            # exactly ``routing_token=<token>``.
            open_metadata.append(
                ("x-goog-request-params", f"routing_token={routing_token}")
            )

        # An empty list is passed as None, i.e. "no metadata".
        await stream.open(metadata=open_metadata or None)

        # Only overwrite when the new stream actually reports a value.
        if stream.generation_number:
            self.generation = stream.generation_number
        if stream.read_handle:
            self.read_handle = stream.read_handle

        self.read_obj_str = stream
        self._is_stream_open = True

        return stream

    return factory
331376
332377 async def download_ranges (
333378 self ,
@@ -353,32 +398,8 @@ async def download_ranges(
353398 * (0, 0, buffer) : downloads 0 to end , i.e. entire object.
354399 * (100, 0, buffer) : downloads from 100 to end.
355400
356-
357401 :type lock: asyncio.Lock
358- :param lock: (Optional) An asyncio lock to synchronize sends and recvs
359- on the underlying bidi-GRPC stream. This is required when multiple
360- coroutines are calling this method concurrently.
361-
362- i.e. Example usage with multiple coroutines:
363-
364- ```
365- lock = asyncio.Lock()
366- task1 = asyncio.create_task(mrd.download_ranges(ranges1, lock))
367- task2 = asyncio.create_task(mrd.download_ranges(ranges2, lock))
368- await asyncio.gather(task1, task2)
369-
370- ```
371-
372- If user want to call this method serially from multiple coroutines,
373- then providing a lock is not necessary.
374-
375- ```
376- await mrd.download_ranges(ranges1)
377- await mrd.download_ranges(ranges2)
378-
379- # ... some other code code...
380-
381- ```
402+ :param lock: (Deprecated) This parameter is deprecated and has no effect.
382403
383404 :type retry_policy: :class:`~google.api_core.retry_async.AsyncRetry`
384405 :param retry_policy: (Optional) The retry policy to use for the operation.
@@ -397,9 +418,6 @@ async def download_ranges(
397418 if not self ._is_stream_open :
398419 raise ValueError ("Underlying bidi-gRPC stream is not open" )
399420
400- if lock is None :
401- lock = asyncio .Lock ()
402-
403421 if retry_policy is None :
404422 retry_policy = AsyncRetry (predicate = _is_read_retryable )
405423
@@ -419,99 +437,98 @@ async def download_ranges(
419437 "routing_token" : None ,
420438 }
421439
422- # Track attempts to manage stream reuse
423- attempt_count = 0
424-
425- def send_ranges_and_get_bytes (
426- requests : List [_storage_v2 .ReadRange ],
427- state : Dict [str , Any ],
428- metadata : Optional [List [Tuple [str , str ]]] = None ,
429- ):
430- async def generator ():
431- nonlocal attempt_count
432- attempt_count += 1
433-
434- if attempt_count > 1 :
435- logger .info (
436- f"Resuming download (attempt { attempt_count } ) for { len (requests )} ranges."
437- )
438-
439- async with lock :
440- current_handle = state .get ("read_handle" )
441- current_token = state .get ("routing_token" )
440+ read_ids = set (download_states .keys ())
441+ queue = self ._multiplexer .register (read_ids )
442442
443- # We reopen if it's a redirect (token exists) OR if this is a retry
444- # (not first attempt). This prevents trying to send data on a dead
445- # stream from a previous failed attempt.
446- should_reopen = (
447- (attempt_count > 1 )
448- or (current_token is not None )
449- or (metadata is not None )
450- )
443+ try :
444+ attempt_count = 0
445+ last_broken_generation = None
451446
452- if should_reopen :
453- if current_token :
454- logger .info (
455- f"Re-opening stream with routing token: { current_token } "
456- )
447+ def send_and_recv_via_multiplexer (
448+ requests : List [_storage_v2 .ReadRange ],
449+ state : Dict [str , Any ],
450+ ):
451+ async def generator ():
452+ nonlocal attempt_count , last_broken_generation
453+ attempt_count += 1
457454
458- self .read_obj_str = _AsyncReadObjectStream (
459- client = self .client .grpc_client ,
460- bucket_name = self .bucket_name ,
461- object_name = self .object_name ,
462- generation_number = self .generation ,
463- read_handle = current_handle ,
455+ if attempt_count > 1 :
456+ logger .info (
457+ f"Resuming download (attempt { attempt_count } ) for { len (requests )} ranges."
464458 )
465459
466- # Inject routing_token into metadata if present
467- current_metadata = list (metadata ) if metadata else []
468- if current_token :
469- current_metadata .append (
470- (
471- "x-goog-request-params" ,
472- f"routing_token={ current_token } " ,
473- )
474- )
475-
476- await self .read_obj_str .open (
477- metadata = current_metadata if current_metadata else None
460+ # Reopen stream if needed
461+ should_reopen = (
462+ (attempt_count > 1 and last_broken_generation is not None )
463+ or (attempt_count == 1 and metadata is not None )
464+ )
465+ if should_reopen :
466+ broken_gen = (
467+ last_broken_generation
468+ if attempt_count > 1
469+ else self ._multiplexer .stream_generation
478470 )
479- self ._is_stream_open = True
471+ stream_factory = self ._create_stream_factory (state , metadata )
472+ await self ._multiplexer .reopen_stream (broken_gen , stream_factory )
480473
481- pending_read_ids = { r . read_id for r in requests }
474+ my_generation = self . _multiplexer . stream_generation
482475
483476 # Send Requests
477+ pending_read_ids = {r .read_id for r in requests }
484478 for i in range (
485479 0 , len (requests ), _MAX_READ_RANGES_PER_BIDI_READ_REQUEST
486480 ):
487- batch = requests [i : i + _MAX_READ_RANGES_PER_BIDI_READ_REQUEST ]
488- await self .read_obj_str .send (
489- _storage_v2 .BidiReadObjectRequest (read_ranges = batch )
490- )
481+ batch = requests [
482+ i : i + _MAX_READ_RANGES_PER_BIDI_READ_REQUEST
483+ ]
484+ try :
485+ await self ._multiplexer .send (
486+ _storage_v2 .BidiReadObjectRequest (read_ranges = batch )
487+ )
488+ except Exception :
489+ last_broken_generation = my_generation
490+ raise
491491
492+ # Receive Responses
492493 while pending_read_ids :
493- response = await self .read_obj_str .recv ()
494- if response is None :
494+ item = await queue .get ()
495+
496+ if isinstance (item , _StreamEnd ):
497+ if pending_read_ids :
498+ last_broken_generation = my_generation
499+ raise exceptions .ServiceUnavailable (
500+ "Stream ended with pending read_ids"
501+ )
495502 break
496- if response .object_data_ranges :
497- for data_range in response .object_data_ranges :
503+
504+ if isinstance (item , _StreamError ):
505+ if item .generation < my_generation :
506+ continue # stale error, skip
507+ last_broken_generation = item .generation
508+ raise item .exception
509+
510+ # Track completion
511+ if item .object_data_ranges :
512+ for data_range in item .object_data_ranges :
498513 if data_range .range_end :
499514 pending_read_ids .discard (
500515 data_range .read_range .read_id
501516 )
502- yield response
517+ yield item
503518
504- return generator ()
519+ return generator ()
505520
506- strategy = _ReadResumptionStrategy ()
507- retry_manager = _BidiStreamRetryManager (
508- strategy , lambda r , s : send_ranges_and_get_bytes ( r , s , metadata = metadata )
509- )
521+ strategy = _ReadResumptionStrategy ()
522+ retry_manager = _BidiStreamRetryManager (
523+ strategy , send_and_recv_via_multiplexer
524+ )
510525
511- await retry_manager .execute (initial_state , retry_policy )
526+ await retry_manager .execute (initial_state , retry_policy )
512527
513- if initial_state .get ("read_handle" ):
514- self .read_handle = initial_state ["read_handle" ]
528+ if initial_state .get ("read_handle" ):
529+ self .read_handle = initial_state ["read_handle" ]
530+ finally :
531+ self ._multiplexer .unregister (read_ids )
515532
516533 async def close (self ):
517534 """
@@ -520,8 +537,15 @@ async def close(self):
520537 if not self ._is_stream_open :
521538 raise ValueError ("Underlying bidi-gRPC stream is not open" )
522539
540+ if self ._multiplexer :
541+ await self ._multiplexer .close ()
542+ self ._multiplexer = None
543+
523544 if self .read_obj_str :
524- await self .read_obj_str .close ()
545+ try :
546+ await self .read_obj_str .close ()
547+ except (asyncio .CancelledError , exceptions .GoogleAPICallError ):
548+ pass
525549 self .read_obj_str = None
526550 self ._is_stream_open = False
527551
0 commit comments