44# you may not use this file except in compliance with the License.
55# You may obtain a copy of the License at
66#
7- # http ://www.apache.org/licenses/LICENSE-2.0
7+ # https ://www.apache.org/licenses/LICENSE-2.0
88#
99# Unless required by applicable law or agreed to in writing, software
1010# distributed under the License is distributed on an "AS IS" BASIS,
@@ -46,8 +46,8 @@ class _StreamMultiplexer:
4646 """Multiplexes concurrent download tasks over a single bidi-gRPC stream.
4747
4848 Routes responses from a background recv loop to per-task asyncio.Queues
49- keyed by read_id. Serializes sends via lock. Coordinates stream reopening
50- via generation-gated locking.
49+ keyed by read_id. Coordinates stream reopening via generation-gated
50+ locking.
5151
5252 A slow consumer on one task will slow down the entire shared connection
5353 due to bounded queue backpressure propagating through gRPC flow control.
@@ -61,11 +61,118 @@ def __init__(
6161 self ._stream = stream
6262 self ._stream_generation : int = 0
6363 self ._queues : Dict [int , asyncio .Queue ] = {}
64- self ._send_lock = asyncio .Lock ()
6564 self ._reopen_lock = asyncio .Lock ()
6665 self ._recv_task : Optional [asyncio .Task ] = None
6766 self ._queue_max_size = queue_max_size
6867
6968 @property
7069 def stream_generation (self ) -> int :
7170 return self ._stream_generation
71+
72+ def register (self , read_ids : Set [int ]) -> asyncio .Queue :
73+ """Register read_ids for a task and return its response queue."""
74+ queue = asyncio .Queue (maxsize = self ._queue_max_size )
75+ for read_id in read_ids :
76+ self ._queues [read_id ] = queue
77+ return queue
78+
79+ def unregister (self , read_ids : Set [int ]) -> None :
80+ """Remove read_ids from routing. Stops recv loop if no tasks remain."""
81+ for read_id in read_ids :
82+ self ._queues .pop (read_id , None )
83+
84+ def _get_unique_queues (self ) -> Set [asyncio .Queue ]:
85+ return set (self ._queues .values ())
86+
87+ def _ensure_recv_loop (self ) -> None :
88+ if self ._recv_task is None or self ._recv_task .done ():
89+ self ._recv_task = asyncio .create_task (self ._recv_loop ())
90+
91+ def _stop_recv_loop (self ) -> None :
92+ if self ._recv_task and not self ._recv_task .done ():
93+ self ._recv_task .cancel ()
94+
95+ def _put_error_nowait (self , queue : asyncio .Queue , error : _StreamError ) -> None :
96+ while True :
97+ try :
98+ queue .put_nowait (error )
99+ break
100+ except asyncio .QueueFull :
101+ try :
102+ queue .get_nowait ()
103+ except asyncio .QueueEmpty :
104+ pass
105+
106+ async def _recv_loop (self ) -> None :
107+ try :
108+ while True :
109+ response = await self ._stream .recv ()
110+ if response is None :
111+ sentinel = _StreamEnd ()
112+ for queue in self ._get_unique_queues ():
113+ await queue .put (sentinel )
114+ return
115+
116+ if response .object_data_ranges :
117+ queues_to_notify : Set [asyncio .Queue ] = set ()
118+ for data_range in response .object_data_ranges :
119+ read_id = data_range .read_range .read_id
120+ queue = self ._queues .get (read_id )
121+ if queue :
122+ queues_to_notify .add (queue )
123+ for queue in queues_to_notify :
124+ await queue .put (response )
125+ else :
126+ for queue in self ._get_unique_queues ():
127+ await queue .put (response )
128+ except asyncio .CancelledError :
129+ raise
130+ except Exception as e :
131+ error = _StreamError (e , self ._stream_generation )
132+ for queue in self ._get_unique_queues ():
133+ self ._put_error_nowait (queue , error )
134+
135+ async def send (self , request : _storage_v2 .BidiReadObjectRequest ) -> int :
136+ self ._ensure_recv_loop ()
137+ await self ._stream .send (request )
138+ return self ._stream_generation
139+
140+ async def reopen_stream (
141+ self ,
142+ broken_generation : int ,
143+ stream_factory : Callable [[], Awaitable [_AsyncReadObjectStream ]],
144+ ) -> None :
145+ async with self ._reopen_lock :
146+ if self ._stream_generation != broken_generation :
147+ return
148+ self ._stop_recv_loop ()
149+ if self ._recv_task :
150+ try :
151+ await self ._recv_task
152+ except (asyncio .CancelledError , Exception ):
153+ pass
154+ error = _StreamError (
155+ Exception ("Stream reopening" ), self ._stream_generation
156+ )
157+ for queue in self ._get_unique_queues ():
158+ self ._put_error_nowait (queue , error )
159+ try :
160+ await self ._stream .close ()
161+ except Exception :
162+ pass
163+ self ._stream = await stream_factory ()
164+ self ._stream_generation += 1
165+ self ._ensure_recv_loop ()
166+
167+ async def close (self ) -> None :
168+ self ._stop_recv_loop ()
169+ if self ._recv_task :
170+ try :
171+ await self ._recv_task
172+ except (asyncio .CancelledError , Exception ):
173+ pass
174+ error = _StreamError (
175+ Exception ("Multiplexer closed" ), self ._stream_generation
176+ )
177+ for queue in self ._get_unique_queues ():
178+ self ._put_error_nowait (queue , error )
0 commit comments