Skip to content

Commit da7ac6e

Browse files
committed
feat: implement _StreamMultiplexer register, recv loop, send, reopen, close
1 parent 0a6ea84 commit da7ac6e

File tree

2 files changed

+559
-4
lines changed

2 files changed

+559
-4
lines changed

packages/google-cloud-storage/google/cloud/storage/asyncio/_stream_multiplexer.py

Lines changed: 111 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# you may not use this file except in compliance with the License.
55
# You may obtain a copy of the License at
66
#
7-
# http://www.apache.org/licenses/LICENSE-2.0
7+
# https://www.apache.org/licenses/LICENSE-2.0
88
#
99
# Unless required by applicable law or agreed to in writing, software
1010
# distributed under the License is distributed on an "AS IS" BASIS,
@@ -46,8 +46,8 @@ class _StreamMultiplexer:
4646
"""Multiplexes concurrent download tasks over a single bidi-gRPC stream.
4747
4848
Routes responses from a background recv loop to per-task asyncio.Queues
49-
keyed by read_id. Serializes sends via lock. Coordinates stream reopening
50-
via generation-gated locking.
49+
keyed by read_id. Coordinates stream reopening via generation-gated
50+
locking.
5151
5252
A slow consumer on one task will slow down the entire shared connection
5353
due to bounded queue backpressure propagating through gRPC flow control.
@@ -61,11 +61,118 @@ def __init__(
6161
self._stream = stream
6262
self._stream_generation: int = 0
6363
self._queues: Dict[int, asyncio.Queue] = {}
64-
self._send_lock = asyncio.Lock()
6564
self._reopen_lock = asyncio.Lock()
6665
self._recv_task: Optional[asyncio.Task] = None
6766
self._queue_max_size = queue_max_size
6867

6968
@property
def stream_generation(self) -> int:
    """Current generation number of the underlying bidi stream.

    Incremented each time the stream is successfully reopened.
    """
    current_generation = self._stream_generation
    return current_generation
71+
72+
def register(self, read_ids: Set[int]) -> asyncio.Queue:
    """Register read_ids for a task and return its response queue.

    All of the task's read_ids share a single bounded queue, so responses
    for any of them are delivered to the same consumer.
    """
    response_queue: asyncio.Queue = asyncio.Queue(maxsize=self._queue_max_size)
    self._queues.update(dict.fromkeys(read_ids, response_queue))
    return response_queue
78+
79+
def unregister(self, read_ids: Set[int]) -> None:
    """Remove read_ids from routing. Stops recv loop if no tasks remain.

    Args:
        read_ids: Read identifiers whose queue routing should be removed.
            Unknown ids are ignored.
    """
    for read_id in read_ids:
        self._queues.pop(read_id, None)
    # Fix: the docstring promised stopping the recv loop when no tasks
    # remain, but the original body never did. With no registered queues
    # left, received responses have nowhere to be routed, so cancel the
    # background task instead of leaving it pumping an idle stream.
    if not self._queues:
        self._stop_recv_loop()
83+
84+
def _get_unique_queues(self) -> Set[asyncio.Queue]:
85+
return set(self._queues.values())
86+
87+
def _ensure_recv_loop(self) -> None:
88+
if self._recv_task is None or self._recv_task.done():
89+
self._recv_task = asyncio.create_task(self._recv_loop())
90+
91+
def _stop_recv_loop(self) -> None:
92+
if self._recv_task and not self._recv_task.done():
93+
self._recv_task.cancel()
94+
95+
def _put_error_nowait(self, queue: asyncio.Queue, error: _StreamError) -> None:
96+
while True:
97+
try:
98+
queue.put_nowait(error)
99+
break
100+
except asyncio.QueueFull:
101+
try:
102+
queue.get_nowait()
103+
except asyncio.QueueEmpty:
104+
pass
105+
106+
async def _recv_loop(self) -> None:
    """Pump responses off the shared stream and fan them out to queues.

    A ``None`` response means the server half-closed the stream: every
    registered queue gets a ``_StreamEnd`` sentinel and the loop exits.
    Responses carrying data ranges are routed only to the queues that own
    one of the contained read_ids; responses without data ranges are
    broadcast to all queues. On any non-cancellation error, a
    ``_StreamError`` tagged with the current generation is force-delivered
    to every queue.
    """
    try:
        while True:
            response = await self._stream.recv()

            if response is None:
                # Clean end of stream: notify everyone and stop pumping.
                end_sentinel = _StreamEnd()
                for waiter in self._get_unique_queues():
                    await waiter.put(end_sentinel)
                return

            if response.object_data_ranges:
                # Route only to the queues owning a read_id in this
                # response; unknown read_ids are silently skipped.
                recipients: Set[asyncio.Queue] = {
                    owner
                    for data_range in response.object_data_ranges
                    if (owner := self._queues.get(data_range.read_range.read_id))
                }
            else:
                # Metadata-only response: broadcast to every queue.
                recipients = self._get_unique_queues()

            for waiter in recipients:
                await waiter.put(response)
    except asyncio.CancelledError:
        raise
    except Exception as exc:
        # Surface the failure (tagged with the generation it occurred on)
        # to all waiters without blocking the loop on full queues.
        failure = _StreamError(exc, self._stream_generation)
        for waiter in self._get_unique_queues():
            self._put_error_nowait(waiter, failure)
134+
135+
async def send(self, request: _storage_v2.BidiReadObjectRequest) -> int:
    """Send *request* on the shared stream, starting the recv loop if
    it is not already running.

    Returns:
        The stream generation the request was sent on, so callers can
        detect whether a later failure belongs to this generation.
    """
    self._ensure_recv_loop()
    shared_stream = self._stream
    await shared_stream.send(request)
    return self._stream_generation
139+
140+
async def reopen_stream(
    self,
    broken_generation: int,
    stream_factory: Callable[[], Awaitable[_AsyncReadObjectStream]],
) -> None:
    """Replace the broken stream with a fresh one, at most once per generation.

    The generation gate under ``_reopen_lock`` makes concurrent callers
    idempotent: whoever arrives after the generation has already advanced
    past *broken_generation* returns immediately.
    """
    async with self._reopen_lock:
        # Someone else already reopened past this generation.
        if self._stream_generation != broken_generation:
            return

        # Tear down the recv loop before touching the stream.
        self._stop_recv_loop()
        pending_task = self._recv_task
        if pending_task:
            try:
                await pending_task
            except (asyncio.CancelledError, Exception):
                pass

        # Fail any waiters still parked on the old generation's queues.
        reopen_error = _StreamError(
            Exception("Stream reopening"), self._stream_generation
        )
        for waiter_queue in self._get_unique_queues():
            self._put_error_nowait(waiter_queue, reopen_error)

        # Best-effort close of the broken stream; it may already be dead.
        try:
            await self._stream.close()
        except Exception:
            pass

        self._stream = await stream_factory()
        self._stream_generation += 1
        self._ensure_recv_loop()
166+
167+
async def close(self) -> None:
    """Shut down the multiplexer.

    Stops the background recv loop, waits for it to finish, and delivers a
    terminal ``_StreamError`` to every registered queue so no consumer is
    left waiting forever. Finally closes the underlying stream.
    """
    self._stop_recv_loop()
    if self._recv_task:
        try:
            await self._recv_task
        except (asyncio.CancelledError, Exception):
            pass
    error = _StreamError(
        Exception("Multiplexer closed"), self._stream_generation
    )
    for queue in self._get_unique_queues():
        self._put_error_nowait(queue, error)
    # Fix: the multiplexer owns the stream (reopen_stream closes and
    # replaces it), but close() previously leaked it. Best-effort close,
    # mirroring reopen_stream's handling of an already-broken stream.
    try:
        await self._stream.close()
    except Exception:
        pass

0 commit comments

Comments
 (0)