Skip to content

Commit cf7c949

Browse files
committed
feat: add _StreamMultiplexer for asyncio bidi-gRPC streams
1 parent 14abfd5 commit cf7c949

File tree

2 files changed

+701
-0
lines changed

2 files changed

+701
-0
lines changed
Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import asyncio
18+
import logging
19+
from typing import Awaitable, Callable, Dict, Optional, Set
20+
21+
from google.cloud import _storage_v2
22+
from google.cloud.storage.asyncio.async_read_object_stream import (
23+
_AsyncReadObjectStream,
24+
)
25+
26+
# Module-level logger for multiplexer diagnostics.
logger = logging.getLogger(__name__)

# Default capacity of each per-task response queue; bounds buffered responses
# and provides backpressure onto the shared stream via gRPC flow control.
_DEFAULT_QUEUE_MAX_SIZE = 100
# Seconds to wait for a full queue before dropping an item so that one
# stalled consumer cannot hang the shared recv loop indefinitely.
_DEFAULT_PUT_TIMEOUT = 20.0
30+
31+
32+
class _StreamError:
33+
"""Wraps an error with the stream generation that produced it."""
34+
35+
def __init__(self, exception: Exception, generation: int):
36+
self.exception = exception
37+
self.generation = generation
38+
39+
40+
class _StreamEnd:
41+
"""Signals the stream closed normally."""
42+
43+
pass
44+
45+
46+
class _StreamMultiplexer:
    """Multiplexes concurrent download tasks over a single bidi-gRPC stream.

    Routes responses from a background recv loop to per-task asyncio.Queues
    keyed by read_id. Coordinates stream reopening via generation-gated
    locking.

    A slow consumer on one task will slow down the entire shared connection
    due to bounded queue backpressure propagating through gRPC flow control.
    """

    def __init__(
        self,
        stream: _AsyncReadObjectStream,
        queue_max_size: int = _DEFAULT_QUEUE_MAX_SIZE,
    ):
        """Wrap *stream* and route its responses to registered queues.

        Args:
            stream: The open bidi read-object stream to multiplex over.
            queue_max_size: Capacity of each per-task response queue.
        """
        self._stream = stream
        # Bumped on every successful reopen so consumers can tell whether an
        # error they observed belongs to the stream they were reading from.
        self._stream_generation: int = 0
        # read_id -> response queue; several read_ids may share one queue.
        self._queues: Dict[int, asyncio.Queue] = {}
        # Serializes reopen_stream so only one caller replaces the stream.
        self._reopen_lock = asyncio.Lock()
        self._recv_task: Optional[asyncio.Task] = None
        self._queue_max_size = queue_max_size

    @property
    def stream_generation(self) -> int:
        """Generation counter of the currently-open stream."""
        return self._stream_generation

    def register(self, read_ids: Set[int]) -> asyncio.Queue:
        """Register read_ids for a task and return its response queue."""
        queue = asyncio.Queue(maxsize=self._queue_max_size)
        for read_id in read_ids:
            self._queues[read_id] = queue
        return queue

    def unregister(self, read_ids: Set[int]) -> None:
        """Remove read_ids from routing."""
        for read_id in read_ids:
            self._queues.pop(read_id, None)

    def _get_unique_queues(self) -> Set[asyncio.Queue]:
        """Return the distinct queues currently registered."""
        return set(self._queues.values())

    async def _put_with_timeout(self, queue: asyncio.Queue, item) -> None:
        """Put *item* on *queue*, dropping it if the queue stays full too long.

        Dropping (rather than blocking forever) keeps one stalled consumer
        from wedging the shared recv loop.
        """
        try:
            await asyncio.wait_for(queue.put(item), timeout=_DEFAULT_PUT_TIMEOUT)
        except asyncio.TimeoutError:
            if queue not in self._get_unique_queues():
                # The consumer unregistered while we were waiting; harmless.
                logger.debug("Dropped item for unregistered queue.")
            else:
                logger.warning(
                    "Queue full for too long. Dropping item to prevent multiplexer hang."
                )

    async def _broadcast(self, item) -> None:
        """Deliver *item* to every distinct registered queue concurrently."""
        await asyncio.gather(
            *(
                self._put_with_timeout(queue, item)
                for queue in self._get_unique_queues()
            )
        )

    def _ensure_recv_loop(self) -> None:
        """Start the background recv loop if it is not already running."""
        if self._recv_task is None or self._recv_task.done():
            self._recv_task = asyncio.create_task(self._recv_loop())

    def _stop_recv_loop(self) -> None:
        """Request cancellation of the recv loop (does not wait for it)."""
        if self._recv_task and not self._recv_task.done():
            self._recv_task.cancel()

    async def _drain_recv_loop(self) -> None:
        """Cancel the recv loop and wait until it has fully stopped.

        Factored out of reopen_stream/close, which previously duplicated
        this cancel-then-await sequence verbatim.
        """
        self._stop_recv_loop()
        if self._recv_task:
            try:
                await self._recv_task
            except (asyncio.CancelledError, Exception):
                # CancelledError is a BaseException (3.8+), hence listed
                # explicitly; any terminal stream error was already fanned
                # out to consumer queues by _recv_loop itself.
                pass

    def _put_error_nowait(self, queue: asyncio.Queue, error: _StreamError) -> None:
        """Force *error* onto *queue*, evicting the oldest item if full.

        Errors must never be silently dropped, so data items are shed from
        a full queue until the error fits.
        """
        while True:
            try:
                queue.put_nowait(error)
                break
            except asyncio.QueueFull:
                try:
                    queue.get_nowait()
                except asyncio.QueueEmpty:
                    # Raced with a consumer draining the queue; retry put.
                    pass

    def _fail_all_queues(self, error: _StreamError) -> None:
        """Deliver *error* to every registered queue without blocking."""
        for queue in self._get_unique_queues():
            self._put_error_nowait(queue, error)

    async def _recv_loop(self) -> None:
        """Read responses off the stream and route them to consumer queues.

        Runs until the stream ends (recv returns None), the task is
        cancelled, or the stream raises; a raise is converted into a
        _StreamError tagged with the current generation and fanned out to
        every queue.
        """
        try:
            while True:
                response = await self._stream.recv()
                if response is None:
                    # Normal end-of-stream: notify every consumer and exit.
                    await self._broadcast(_StreamEnd())
                    return

                if response.object_data_ranges:
                    # Route only to queues owning a read_id present in this
                    # response, deduplicated so each queue gets it once.
                    queues_to_notify: Set[asyncio.Queue] = set()
                    for data_range in response.object_data_ranges:
                        read_id = data_range.read_range.read_id
                        queue = self._queues.get(read_id)
                        if queue:
                            queues_to_notify.add(queue)
                    await asyncio.gather(
                        *(
                            self._put_with_timeout(queue, response)
                            for queue in queues_to_notify
                        )
                    )
                else:
                    # No data ranges (e.g. metadata-only response): every
                    # consumer may need to see it.
                    await self._broadcast(response)
        except asyncio.CancelledError:
            raise
        except Exception as e:
            # Tag the failure with the generation that produced it so
            # consumers can tell whether a reopen has already happened.
            self._fail_all_queues(_StreamError(e, self._stream_generation))

    async def send(self, request: _storage_v2.BidiReadObjectRequest) -> int:
        """Send *request* on the stream and return the generation used."""
        self._ensure_recv_loop()
        await self._stream.send(request)
        return self._stream_generation

    async def reopen_stream(
        self,
        broken_generation: int,
        stream_factory: Callable[[], Awaitable[_AsyncReadObjectStream]],
    ) -> None:
        """Replace the stream, unless another caller already reopened it.

        Args:
            broken_generation: The generation the caller observed failing.
                If stale, the stream was already replaced and this is a
                no-op.
            stream_factory: Awaitable factory producing the new stream.
        """
        async with self._reopen_lock:
            if self._stream_generation != broken_generation:
                # A concurrent caller already reopened; nothing to do.
                return
            await self._drain_recv_loop()
            self._fail_all_queues(
                _StreamError(Exception("Stream reopening"), self._stream_generation)
            )
            try:
                await self._stream.close()
            except Exception:
                # Best effort: the old stream is likely already broken.
                pass
            self._stream = await stream_factory()
            self._stream_generation += 1
            self._ensure_recv_loop()

    async def close(self) -> None:
        """Shut down the recv loop, fail consumers, and close the stream."""
        await self._drain_recv_loop()
        self._fail_all_queues(
            _StreamError(Exception("Multiplexer closed"), self._stream_generation)
        )
        # Bug fix: the underlying stream was previously left open here,
        # leaking the connection when the multiplexer was discarded. Close
        # it best-effort, mirroring reopen_stream.
        try:
            await self._stream.close()
        except Exception:
            pass

0 commit comments

Comments
 (0)