diff --git a/packages/google-cloud-storage/cloudbuild/run_zonal_tests.sh b/packages/google-cloud-storage/cloudbuild/run_zonal_tests.sh index 2d42ce6d5f68..2de098aefbc6 100644 --- a/packages/google-cloud-storage/cloudbuild/run_zonal_tests.sh +++ b/packages/google-cloud-storage/cloudbuild/run_zonal_tests.sh @@ -4,10 +4,13 @@ echo '--- Installing git and cloning repository on VM ---' sudo apt-get update && sudo apt-get install -y git python3-pip python3-venv # Clone the repository and checkout the specific commit from the build trigger. -git clone https://github.com/googleapis/python-storage.git -cd python-storage -git fetch origin "refs/pull/${_PR_NUMBER}/head" +git clone https://github.com/googleapis/google-cloud-python.git +cd google-cloud-python +if [ -n "${_PR_NUMBER:-}" ]; then + git fetch origin "refs/pull/${_PR_NUMBER}/head" +fi git checkout ${COMMIT_SHA} +cd packages/google-cloud-storage echo '--- Installing Python and dependencies on VM ---' @@ -27,4 +30,4 @@ export GCE_METADATA_MTLS_MODE=None CURRENT_ULIMIT=$(ulimit -n) echo '--- Running Zonal tests on VM with ulimit set to ---' $CURRENT_ULIMIT pytest -vv -s --log-format='%(asctime)s %(levelname)s %(message)s' --log-date-format='%H:%M:%S' tests/system/test_zonal.py -pytest -vv -s --log-format='%(asctime)s %(levelname)s %(message)s' --log-date-format='%H:%M:%S' samples/snippets/zonal_buckets/zonal_snippets_test.py +# pytest -vv -s --log-format='%(asctime)s %(levelname)s %(message)s' --log-date-format='%H:%M:%S' samples/snippets/zonal_buckets/zonal_snippets_test.py diff --git a/packages/google-cloud-storage/cloudbuild/zb-system-tests-cloudbuild.yaml b/packages/google-cloud-storage/cloudbuild/zb-system-tests-cloudbuild.yaml index 26daa8ae92d9..10de1f07ffcf 100644 --- a/packages/google-cloud-storage/cloudbuild/zb-system-tests-cloudbuild.yaml +++ b/packages/google-cloud-storage/cloudbuild/zb-system-tests-cloudbuild.yaml @@ -49,6 +49,7 @@ steps: # The VM is deleted after tests are run, regardless of success. 
- name: "gcr.io/google.com/cloudsdktool/cloud-sdk" id: "run-tests-and-delete-vm" + dir: "packages/google-cloud-storage" entrypoint: "bash" args: - "-c" diff --git a/packages/google-cloud-storage/google/cloud/storage/asyncio/_stream_multiplexer.py b/packages/google-cloud-storage/google/cloud/storage/asyncio/_stream_multiplexer.py new file mode 100644 index 000000000000..d8c569c6879a --- /dev/null +++ b/packages/google-cloud-storage/google/cloud/storage/asyncio/_stream_multiplexer.py @@ -0,0 +1,201 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import asyncio +import logging +from typing import Awaitable, Callable, Dict, Optional, Set + +from google.cloud import _storage_v2 +from google.cloud.storage.asyncio.async_read_object_stream import ( + _AsyncReadObjectStream, +) + +logger = logging.getLogger(__name__) + +_DEFAULT_QUEUE_MAX_SIZE = 100 +_DEFAULT_PUT_TIMEOUT_SECONDS = 20.0 + + +class _StreamError: + """Wraps an error with the stream generation that produced it.""" + + def __init__(self, exception: Exception, generation: int): + self.exception = exception + self.generation = generation + + +class _StreamEnd: + """Signals the stream closed normally.""" + + pass + + +class _StreamMultiplexer: + """Multiplexes concurrent download tasks over a single bidi-gRPC stream. + + Routes responses from a background recv loop to per-task asyncio.Queues + keyed by read_id. 
Coordinates stream reopening via generation-gated + locking. + + A slow consumer on one task will slow down the entire shared connection + due to bounded queue backpressure propagating through gRPC flow control. + """ + + def __init__( + self, + stream: _AsyncReadObjectStream, + queue_max_size: int = _DEFAULT_QUEUE_MAX_SIZE, + ): + self._stream = stream + self._stream_generation: int = 0 + self._queues: Dict[int, asyncio.Queue] = {} + self._reopen_lock = asyncio.Lock() + self._recv_task: Optional[asyncio.Task] = None + self._queue_max_size = queue_max_size + + @property + def stream_generation(self) -> int: + return self._stream_generation + + def register(self, read_ids: Set[int]) -> asyncio.Queue: + """Register read_ids for a task and return its response queue.""" + queue = asyncio.Queue(maxsize=self._queue_max_size) + for read_id in read_ids: + self._queues[read_id] = queue + return queue + + def unregister(self, read_ids: Set[int]) -> None: + """Remove read_ids from routing.""" + for read_id in read_ids: + self._queues.pop(read_id, None) + + def _get_unique_queues(self) -> Set[asyncio.Queue]: + return set(self._queues.values()) + + async def _put_with_timeout(self, queue: asyncio.Queue, item) -> None: + try: + await asyncio.wait_for( + queue.put(item), timeout=_DEFAULT_PUT_TIMEOUT_SECONDS + ) + except asyncio.TimeoutError: + if queue not in self._get_unique_queues(): + logger.debug("Dropped item for unregistered queue.") + else: + logger.warning( + "Queue full for too long. Dropping item to prevent multiplexer hang." 
+ ) + + def _ensure_recv_loop(self) -> None: + if self._recv_task is None or self._recv_task.done(): + self._recv_task = asyncio.create_task(self._recv_loop()) + + def _stop_recv_loop(self) -> None: + if self._recv_task and not self._recv_task.done(): + self._recv_task.cancel() + + def _put_error_nowait(self, queue: asyncio.Queue, error: _StreamError) -> None: + while True: + try: + queue.put_nowait(error) + break + except asyncio.QueueFull: + try: + queue.get_nowait() + except asyncio.QueueEmpty: + pass + + async def _recv_loop(self) -> None: + try: + while True: + response = await self._stream.recv() + if response is None: + sentinel = _StreamEnd() + await asyncio.gather( + *( + self._put_with_timeout(queue, sentinel) + for queue in self._get_unique_queues() + ) + ) + return + + if response.object_data_ranges: + queues_to_notify: Set[asyncio.Queue] = set() + for data_range in response.object_data_ranges: + read_id = data_range.read_range.read_id + queue = self._queues.get(read_id) + if queue: + queues_to_notify.add(queue) + await asyncio.gather( + *( + self._put_with_timeout(queue, response) + for queue in queues_to_notify + ) + ) + else: + await asyncio.gather( + *( + self._put_with_timeout(queue, response) + for queue in self._get_unique_queues() + ) + ) + except asyncio.CancelledError: + raise + except Exception as e: + logger.warning(f"Stream multiplexer recv loop failed: {e}", exc_info=True) + error = _StreamError(e, self._stream_generation) + for queue in self._get_unique_queues(): + self._put_error_nowait(queue, error) + + async def send(self, request: _storage_v2.BidiReadObjectRequest) -> int: + self._ensure_recv_loop() + await self._stream.send(request) + return self._stream_generation + + async def reopen_stream( + self, + broken_generation: int, + stream_factory: Callable[[], Awaitable[_AsyncReadObjectStream]], + ) -> None: + async with self._reopen_lock: + if self._stream_generation != broken_generation: + return + self._stop_recv_loop() + if 
self._recv_task: + try: + await self._recv_task + except (asyncio.CancelledError, Exception): + pass + error = _StreamError(Exception("Stream reopening"), self._stream_generation) + for queue in self._get_unique_queues(): + self._put_error_nowait(queue, error) + try: + await self._stream.close() + except Exception: + pass + self._stream = await stream_factory() + self._stream_generation += 1 + self._ensure_recv_loop() + + async def close(self) -> None: + self._stop_recv_loop() + if self._recv_task: + try: + await self._recv_task + except (asyncio.CancelledError, Exception): + pass + error = _StreamError(Exception("Multiplexer closed"), self._stream_generation) + for queue in self._get_unique_queues(): + self._put_error_nowait(queue, error) diff --git a/packages/google-cloud-storage/google/cloud/storage/asyncio/async_multi_range_downloader.py b/packages/google-cloud-storage/google/cloud/storage/asyncio/async_multi_range_downloader.py index cea21cb9ae66..ac0844519e2d 100644 --- a/packages/google-cloud-storage/google/cloud/storage/asyncio/async_multi_range_downloader.py +++ b/packages/google-cloud-storage/google/cloud/storage/asyncio/async_multi_range_downloader.py @@ -25,6 +25,11 @@ from google.cloud import _storage_v2 from google.cloud.storage._helpers import generate_random_56_bit_integer +from google.cloud.storage.asyncio._stream_multiplexer import ( + _StreamEnd, + _StreamError, + _StreamMultiplexer, +) from google.cloud.storage.asyncio.async_grpc_client import ( AsyncGrpcClient, ) @@ -224,9 +229,7 @@ def __init__( self.read_obj_str: Optional[_AsyncReadObjectStream] = None self._is_stream_open: bool = False self._routing_token: Optional[str] = None - self._read_id_to_writable_buffer_dict = {} - self._read_id_to_download_ranges_id = {} - self._download_ranges_id_to_pending_read_ids = {} + self._multiplexer: Optional[_StreamMultiplexer] = None self.persisted_size: Optional[int] = None # updated after opening the stream self._open_retries: int = 0 @@ -328,6 +331,45 
@@ async def _do_open(): self._is_stream_open = True await retry_policy(_do_open)() + self._multiplexer = _StreamMultiplexer(self.read_obj_str) + + def _create_stream_factory(self, state, metadata): + """Create a factory that opens a new stream with current routing state.""" + + async def factory(): + current_handle = state.get("read_handle") + current_token = state.get("routing_token") + + stream = _AsyncReadObjectStream( + client=self.client.grpc_client, + bucket_name=self.bucket_name, + object_name=self.object_name, + generation_number=self.generation, + read_handle=current_handle, + ) + + current_metadata = list(metadata) if metadata else [] + if current_token: + current_metadata.append( + ( + "x-goog-request-params", + f"routing_token={current_token}", + ) + ) + + await stream.open(metadata=current_metadata if current_metadata else None) + + if stream.generation_number: + self.generation = stream.generation_number + if stream.read_handle: + self.read_handle = stream.read_handle + + self.read_obj_str = stream + self._is_stream_open = True + + return stream + + return factory async def download_ranges( self, @@ -353,32 +395,8 @@ async def download_ranges( * (0, 0, buffer) : downloads 0 to end , i.e. entire object. * (100, 0, buffer) : downloads from 100 to end. - :type lock: asyncio.Lock - :param lock: (Optional) An asyncio lock to synchronize sends and recvs - on the underlying bidi-GRPC stream. This is required when multiple - coroutines are calling this method concurrently. - - i.e. Example usage with multiple coroutines: - - ``` - lock = asyncio.Lock() - task1 = asyncio.create_task(mrd.download_ranges(ranges1, lock)) - task2 = asyncio.create_task(mrd.download_ranges(ranges2, lock)) - await asyncio.gather(task1, task2) - - ``` - - If user want to call this method serially from multiple coroutines, - then providing a lock is not necessary. - - ``` - await mrd.download_ranges(ranges1) - await mrd.download_ranges(ranges2) - - # ... some other code code... 
- - ``` + :param lock: (Deprecated) This parameter is deprecated and has no effect. :type retry_policy: :class:`~google.api_core.retry_async.AsyncRetry` :param retry_policy: (Optional) The retry policy to use for the operation. @@ -397,9 +415,6 @@ async def download_ranges( if not self._is_stream_open: raise ValueError("Underlying bidi-gRPC stream is not open") - if lock is None: - lock = asyncio.Lock() - if retry_policy is None: retry_policy = AsyncRetry(predicate=_is_read_retryable) @@ -419,99 +434,97 @@ async def download_ranges( "routing_token": None, } - # Track attempts to manage stream reuse - attempt_count = 0 - - def send_ranges_and_get_bytes( - requests: List[_storage_v2.ReadRange], - state: Dict[str, Any], - metadata: Optional[List[Tuple[str, str]]] = None, - ): - async def generator(): - nonlocal attempt_count - attempt_count += 1 - - if attempt_count > 1: - logger.info( - f"Resuming download (attempt {attempt_count}) for {len(requests)} ranges." - ) + read_ids = set(download_states.keys()) + queue = self._multiplexer.register(read_ids) - async with lock: - current_handle = state.get("read_handle") - current_token = state.get("routing_token") + try: + attempt_count = 0 + last_broken_generation = None - # We reopen if it's a redirect (token exists) OR if this is a retry - # (not first attempt). This prevents trying to send data on a dead - # stream from a previous failed attempt. 
- should_reopen = ( - (attempt_count > 1) - or (current_token is not None) - or (metadata is not None) - ) + def send_and_recv_via_multiplexer( + requests: List[_storage_v2.ReadRange], + state: Dict[str, Any], + ): + async def generator(): + nonlocal attempt_count, last_broken_generation + attempt_count += 1 - if should_reopen: - if current_token: - logger.info( - f"Re-opening stream with routing token: {current_token}" - ) - - self.read_obj_str = _AsyncReadObjectStream( - client=self.client.grpc_client, - bucket_name=self.bucket_name, - object_name=self.object_name, - generation_number=self.generation, - read_handle=current_handle, + if attempt_count > 1: + logger.info( + f"Resuming download (attempt {attempt_count}) for {len(requests)} ranges." ) - # Inject routing_token into metadata if present - current_metadata = list(metadata) if metadata else [] - if current_token: - current_metadata.append( - ( - "x-goog-request-params", - f"routing_token={current_token}", - ) - ) - - await self.read_obj_str.open( - metadata=current_metadata if current_metadata else None + # Reopen stream if needed + should_reopen = ( + attempt_count > 1 and last_broken_generation is not None + ) or (attempt_count == 1 and metadata is not None) + if should_reopen: + broken_gen = ( + last_broken_generation + if attempt_count > 1 + else self._multiplexer.stream_generation + ) + stream_factory = self._create_stream_factory(state, metadata) + await self._multiplexer.reopen_stream( + broken_gen, stream_factory ) - self._is_stream_open = True - pending_read_ids = {r.read_id for r in requests} + stream_generation = self._multiplexer.stream_generation # Send Requests + pending_read_ids = {r.read_id for r in requests} for i in range( 0, len(requests), _MAX_READ_RANGES_PER_BIDI_READ_REQUEST ): batch = requests[i : i + _MAX_READ_RANGES_PER_BIDI_READ_REQUEST] - await self.read_obj_str.send( - _storage_v2.BidiReadObjectRequest(read_ranges=batch) - ) + try: + await self._multiplexer.send( + 
_storage_v2.BidiReadObjectRequest(read_ranges=batch) + ) + except Exception: + last_broken_generation = stream_generation + raise + # Receive Responses while pending_read_ids: - response = await self.read_obj_str.recv() - if response is None: + item = await queue.get() + + if isinstance(item, _StreamEnd): + if pending_read_ids: + last_broken_generation = stream_generation + raise exceptions.ServiceUnavailable( + "Stream ended with pending read_ids" + ) break - if response.object_data_ranges: - for data_range in response.object_data_ranges: + + if isinstance(item, _StreamError): + if item.generation < stream_generation: + continue # stale error, skip + last_broken_generation = item.generation + raise item.exception + + # Track completion + if item.object_data_ranges: + for data_range in item.object_data_ranges: if data_range.range_end: pending_read_ids.discard( data_range.read_range.read_id ) - yield response + yield item - return generator() + return generator() - strategy = _ReadResumptionStrategy() - retry_manager = _BidiStreamRetryManager( - strategy, lambda r, s: send_ranges_and_get_bytes(r, s, metadata=metadata) - ) + strategy = _ReadResumptionStrategy() + retry_manager = _BidiStreamRetryManager( + strategy, send_and_recv_via_multiplexer + ) - await retry_manager.execute(initial_state, retry_policy) + await retry_manager.execute(initial_state, retry_policy) - if initial_state.get("read_handle"): - self.read_handle = initial_state["read_handle"] + if initial_state.get("read_handle"): + self.read_handle = initial_state["read_handle"] + finally: + self._multiplexer.unregister(read_ids) async def close(self): """ @@ -520,8 +533,15 @@ async def close(self): if not self._is_stream_open: raise ValueError("Underlying bidi-gRPC stream is not open") + if self._multiplexer: + await self._multiplexer.close() + self._multiplexer = None + if self.read_obj_str: - await self.read_obj_str.close() + try: + await self.read_obj_str.close() + except (asyncio.CancelledError, 
exceptions.GoogleAPICallError): + pass self.read_obj_str = None self._is_stream_open = False diff --git a/packages/google-cloud-storage/tests/system/test_zonal.py b/packages/google-cloud-storage/tests/system/test_zonal.py index edd323b037ec..3c275794e331 100644 --- a/packages/google-cloud-storage/tests/system/test_zonal.py +++ b/packages/google-cloud-storage/tests/system/test_zonal.py @@ -2,13 +2,14 @@ import asyncio import gc import os +import random import uuid from io import BytesIO # python additional imports import google_crc32c import pytest -from google.api_core.exceptions import FailedPrecondition, NotFound +from google.api_core.exceptions import FailedPrecondition, NotFound, OutOfRange from google.cloud.storage.asyncio.async_appendable_object_writer import ( _DEFAULT_FLUSH_INTERVAL_BYTES, @@ -594,3 +595,190 @@ async def _run(): gc.collect() event_loop.run_until_complete(_run()) + + +@pytest.mark.parametrize( + "ranges_desc, chunk_ranges", + [ + ("small", [(1, 100)] * 3), + ("medium", [(100, 100000)] * 3), + ("large", [(1000000, 2000000)] * 3), + ("mixed", [(1, 100), (100, 100000), (1000000, 2000000)]), + ], +) +def test_mrd_concurrent_download( + storage_client, + blobs_to_delete, + event_loop, + grpc_client_direct, + ranges_desc, + chunk_ranges, +): + """ + Test that mrd can handle concurrent `download_ranges` calls correctly. + Tests overlapping ranges, minimal concurrency, + parametrized chunk sizes (small/medium/large/mixed), and full object fetching alongside specific chunks. 
+ """ + object_size = 15 * 1024 * 1024 # 15MB + object_name = f"test_mrd_concurrent-{uuid.uuid4()}" + + async def _run(): + object_data = os.urandom(object_size) + + writer = AsyncAppendableObjectWriter( + grpc_client_direct, _ZONAL_BUCKET, object_name + ) + await writer.open() + await writer.append(object_data) + await writer.close(finalize_on_close=True) + + async with AsyncMultiRangeDownloader( + grpc_client_direct, _ZONAL_BUCKET, object_name + ) as mrd: + tasks = [] + ranges_to_fetch = [] + + for min_len, max_len in chunk_ranges: + start = random.randint(0, object_size - max_len) + length = random.randint(min_len, max_len) + ranges_to_fetch.append((start, length)) + + # Full object fetching concurrently + ranges_to_fetch.append((0, 0)) + + random.shuffle(ranges_to_fetch) + + buffers = [BytesIO() for _ in range(len(ranges_to_fetch))] + + for idx, (start, length) in enumerate(ranges_to_fetch): + tasks.append( + asyncio.create_task( + mrd.download_ranges([(start, length, buffers[idx])]) + ) + ) + + await asyncio.gather(*tasks) + + # Validation + for idx, (start, length) in enumerate(ranges_to_fetch): + if length == 0: + expected_data = object_data[start:] + else: + expected_data = object_data[start : start + length] + assert buffers[idx].getvalue() == expected_data + + del writer + gc.collect() + blobs_to_delete.append(storage_client.bucket(_ZONAL_BUCKET).blob(object_name)) + + event_loop.run_until_complete(_run()) + + +def test_mrd_concurrent_download_cancellation( + storage_client, blobs_to_delete, event_loop, grpc_client_direct +): + """ + Test task cancellation / abort mid-stream. + Tests that downloading gracefully manages memory and internal references + when tasks are canceled during active multiplexing, without breaking remaining downloads. 
+ """ + object_size = 5 * 1024 * 1024 # 5MB + object_name = f"test_mrd_cancel-{uuid.uuid4()}" + + async def _run(): + object_data = os.urandom(object_size) + + writer = AsyncAppendableObjectWriter( + grpc_client_direct, _ZONAL_BUCKET, object_name + ) + await writer.open() + await writer.append(object_data) + await writer.close(finalize_on_close=True) + + async with AsyncMultiRangeDownloader( + grpc_client_direct, _ZONAL_BUCKET, object_name + ) as mrd: + tasks = [] + num_chunks = 100 + chunk_size = object_size // num_chunks + buffers = [BytesIO() for _ in range(num_chunks)] + + for i in range(num_chunks): + start = i * chunk_size + tasks.append( + asyncio.create_task( + mrd.download_ranges([(start, chunk_size, buffers[i])]) + ) + ) + + # Let the loop start sending Bidi requests + await asyncio.sleep(0.01) + + # Cancel a subset of evenly distributed tasks + for i in range(0, num_chunks, 2): + tasks[i].cancel() + + results = await asyncio.gather(*tasks, return_exceptions=True) + + for i in range(num_chunks): + if i % 2 == 0: + assert isinstance(results[i], asyncio.CancelledError) + else: + start = i * chunk_size + expected_data = object_data[start : start + chunk_size] + assert buffers[i].getvalue() == expected_data + + del writer + gc.collect() + blobs_to_delete.append(storage_client.bucket(_ZONAL_BUCKET).blob(object_name)) + + event_loop.run_until_complete(_run()) + + +def test_mrd_concurrent_download_out_of_bounds( + storage_client, blobs_to_delete, event_loop, grpc_client_direct +): + """ + Test out-of-bounds & edge ranges concurrent with valid requests. + Verifies isolation: invalid bounds generate correct exceptions and don't stall the stream + for concurrently valid requests. 
+ """ + object_size = 2 * 1024 * 1024 # 2MB + object_name = f"test_mrd_oob-{uuid.uuid4()}" + + async def _run(): + object_data = os.urandom(object_size) + + writer = AsyncAppendableObjectWriter( + grpc_client_direct, _ZONAL_BUCKET, object_name + ) + await writer.open() + await writer.append(object_data) + await writer.close(finalize_on_close=True) + + async with AsyncMultiRangeDownloader( + grpc_client_direct, _ZONAL_BUCKET, object_name + ) as mrd: + valid_buffer = BytesIO() + valid_task = asyncio.create_task( + mrd.download_ranges([(0, 100, valid_buffer)]) + ) + + oob_buffer = BytesIO() + oob_task = asyncio.create_task( + mrd.download_ranges([(object_size + 1000, 100, oob_buffer)]) + ) + + results = await asyncio.gather(valid_task, oob_task, return_exceptions=True) + + # Verify valid one processed correctly + assert valid_buffer.getvalue() == object_data[:100] + + # Verify fully OOB request returned Exception + assert isinstance(results[1], OutOfRange) + + del writer + gc.collect() + blobs_to_delete.append(storage_client.bucket(_ZONAL_BUCKET).blob(object_name)) + + event_loop.run_until_complete(_run()) diff --git a/packages/google-cloud-storage/tests/unit/asyncio/test_async_multi_range_downloader.py b/packages/google-cloud-storage/tests/unit/asyncio/test_async_multi_range_downloader.py index 80df5a438173..1ac17ad3cc8e 100644 --- a/packages/google-cloud-storage/tests/unit/asyncio/test_async_multi_range_downloader.py +++ b/packages/google-cloud-storage/tests/unit/asyncio/test_async_multi_range_downloader.py @@ -105,7 +105,6 @@ async def test_create_mrd(self, mock_cls_async_read_object_stream): async def test_download_ranges_via_async_gather( self, mock_cls_async_read_object_stream, mock_random_int ): - # Arrange data = b"these_are_18_chars" crc32c = Checksum(data).digest() crc32c_int = int.from_bytes(crc32c, "big") @@ -114,55 +113,65 @@ async def test_download_ranges_via_async_gather( ) mock_mrd, _ = await self._make_mock_mrd(mock_cls_async_read_object_stream) - 
mock_random_int.side_effect = [456, 91011] - mock_mrd.read_obj_str.send = AsyncMock() - mock_mrd.read_obj_str.recv = AsyncMock() + send_count = 0 + both_sent = asyncio.Event() + + async def counting_send(request): + nonlocal send_count + send_count += 1 + if send_count >= 2: + both_sent.set() + + mock_mrd.read_obj_str.send = AsyncMock(side_effect=counting_send) + + recv_call_count = 0 + + async def controlled_recv(): + nonlocal recv_call_count + recv_call_count += 1 + if recv_call_count == 1: + await both_sent.wait() + return _storage_v2.BidiReadObjectResponse( + object_data_ranges=[ + _storage_v2.ObjectRangeData( + checksummed_data=_storage_v2.ChecksummedData( + content=data, crc32c=crc32c_int + ), + range_end=True, + read_range=_storage_v2.ReadRange( + read_offset=0, read_length=18, read_id=456 + ), + ) + ] + ) + elif recv_call_count == 2: + return _storage_v2.BidiReadObjectResponse( + object_data_ranges=[ + _storage_v2.ObjectRangeData( + checksummed_data=_storage_v2.ChecksummedData( + content=data[10:16], + crc32c=crc32c_checksum_for_data_slice, + ), + range_end=True, + read_range=_storage_v2.ReadRange( + read_offset=10, read_length=6, read_id=91011 + ), + ) + ], + ) + return None - mock_mrd.read_obj_str.recv.side_effect = [ - _storage_v2.BidiReadObjectResponse( - object_data_ranges=[ - _storage_v2.ObjectRangeData( - checksummed_data=_storage_v2.ChecksummedData( - content=data, crc32c=crc32c_int - ), - range_end=True, - read_range=_storage_v2.ReadRange( - read_offset=0, read_length=18, read_id=456 - ), - ) - ] - ), - _storage_v2.BidiReadObjectResponse( - object_data_ranges=[ - _storage_v2.ObjectRangeData( - checksummed_data=_storage_v2.ChecksummedData( - content=data[10:16], - crc32c=crc32c_checksum_for_data_slice, - ), - range_end=True, - read_range=_storage_v2.ReadRange( - read_offset=10, read_length=6, read_id=91011 - ), - ) - ], - ), - None, - ] + mock_mrd.read_obj_str.recv = AsyncMock(side_effect=controlled_recv) - # Act buffer = BytesIO() second_buffer = 
BytesIO() - lock = asyncio.Lock() - task1 = asyncio.create_task(mock_mrd.download_ranges([(0, 18, buffer)], lock)) - task2 = asyncio.create_task( - mock_mrd.download_ranges([(10, 6, second_buffer)], lock) - ) + task1 = asyncio.create_task(mock_mrd.download_ranges([(0, 18, buffer)])) + task2 = asyncio.create_task(mock_mrd.download_ranges([(10, 6, second_buffer)])) await asyncio.gather(task1, task2) - # Assert assert buffer.getvalue() == data assert second_buffer.getvalue() == data[10:16] @@ -182,7 +191,6 @@ async def test_download_ranges( crc32c_int = int.from_bytes(crc32c, "big") mock_mrd, _ = await self._make_mock_mrd(mock_cls_async_read_object_stream) - mock_random_int.side_effect = [456] mock_mrd.read_obj_str.send = AsyncMock() @@ -203,9 +211,9 @@ async def test_download_ranges( ), None, ] - # Act buffer = BytesIO() + await mock_mrd.download_ranges([(0, 18, buffer)]) # Assert @@ -320,6 +328,7 @@ def test_init_raises_if_crc32c_c_extension_is_missing(self, mock_google_crc32c): async def test_download_ranges_raises_on_checksum_mismatch( self, mock_checksum_class ): + from google.cloud.storage.asyncio._stream_multiplexer import _StreamMultiplexer from google.cloud.storage.asyncio.async_multi_range_downloader import ( AsyncMultiRangeDownloader, ) @@ -353,6 +362,7 @@ async def test_download_ranges_raises_on_checksum_mismatch( mrd = AsyncMultiRangeDownloader(mock_client, "bucket", "object") mrd.read_obj_str = mock_stream mrd._is_stream_open = True + mrd._multiplexer = _StreamMultiplexer(mock_stream) with pytest.raises(DataCorruption) as exc_info: with mock.patch( @@ -431,6 +441,8 @@ async def test_create_mrd_with_generation_number( # Assert assert mrd.generation == _TEST_GENERATION_NUMBER + assert mrd.read_handle == _TEST_READ_HANDLE + assert mrd.persisted_size == _TEST_OBJECT_SIZE assert "'generation_number' is deprecated" in caplog.text @pytest.mark.asyncio @@ -524,41 +536,45 @@ async def test_download_ranges_resumption_logging( # Arrange mock_mrd, _ = await 
self._make_mock_mrd(mock_cls_async_read_object_stream) - mock_mrd.read_obj_str.send = AsyncMock() - mock_mrd.read_obj_str.recv = AsyncMock() - from google.api_core import exceptions as core_exceptions retryable_exc = core_exceptions.ServiceUnavailable("Retry me") - # mock send to raise exception ONCE then succeed - mock_mrd.read_obj_str.send.side_effect = [ - retryable_exc, - None, # Success on second try - ] + mock_mrd.read_obj_str.send = AsyncMock( + side_effect=[ + retryable_exc, + None, + ] + ) - # mock recv for second try - mock_mrd.read_obj_str.recv.side_effect = [ - _storage_v2.BidiReadObjectResponse( - object_data_ranges=[ - _storage_v2.ObjectRangeData( - checksummed_data=_storage_v2.ChecksummedData( - content=b"data", crc32c=123 - ), - range_end=True, - read_range=_storage_v2.ReadRange( - read_offset=0, read_length=4, read_id=123 - ), - ) - ] - ), - None, - ] + recv_call_count = 0 + + async def staged_recv(): + nonlocal recv_call_count + recv_call_count += 1 + if recv_call_count == 1: + return _storage_v2.BidiReadObjectResponse( + object_data_ranges=[ + _storage_v2.ObjectRangeData( + checksummed_data=_storage_v2.ChecksummedData( + content=b"data", crc32c=123 + ), + range_end=True, + read_range=_storage_v2.ReadRange( + read_offset=0, read_length=4, read_id=123 + ), + ) + ] + ) + return None + + mock_mrd.read_obj_str.recv = AsyncMock(side_effect=staged_recv) mock_random_int.return_value = 123 # Act buffer = BytesIO() + # Patch Checksum where it is likely used (reads_resumption_strategy or similar), # but actually if we use google_crc32c directly, we should patch that or provide valid CRC. # Since we can't reliably predict where Checksum is imported/used without more digging, @@ -567,12 +583,10 @@ async def test_download_ranges_resumption_logging( # But we can't force b"data" to have crc=123. # So we MUST patch Checksum. 
# It is used in google.cloud.storage.asyncio.retry.reads_resumption_strategy - with mock.patch( "google.cloud.storage.asyncio.retry.reads_resumption_strategy.Checksum" ) as mock_chk: mock_chk.return_value.digest.return_value = (123).to_bytes(4, "big") - await mock_mrd.download_ranges([(0, 4, buffer)]) # Assert diff --git a/packages/google-cloud-storage/tests/unit/asyncio/test_stream_multiplexer.py b/packages/google-cloud-storage/tests/unit/asyncio/test_stream_multiplexer.py new file mode 100644 index 000000000000..e2a15b482148 --- /dev/null +++ b/packages/google-cloud-storage/tests/unit/asyncio/test_stream_multiplexer.py @@ -0,0 +1,597 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 

import asyncio
from unittest.mock import AsyncMock

import pytest

from google.cloud import _storage_v2
from google.cloud.storage.asyncio._stream_multiplexer import (
    _DEFAULT_QUEUE_MAX_SIZE,
    _StreamEnd,
    _StreamError,
    _StreamMultiplexer,
)


class TestSentinelTypes:
    """Tests for the sentinel objects the multiplexer places on per-task queues."""

    def test_stream_error_stores_exception_and_generation(self):
        # Given an exception and a generation
        exc = ValueError("test")

        # When a StreamError is created
        error = _StreamError(exc, generation=3)

        # Then it stores the exception and generation
        assert error.exception is exc
        assert error.generation == 3


class TestStreamMultiplexerInit:
    """Tests for _StreamMultiplexer construction defaults."""

    def test_init_sets_stream_and_defaults(self):
        # Given a mock stream
        mock_stream = AsyncMock()

        # When a multiplexer is created
        mux = _StreamMultiplexer(mock_stream)

        # Then it sets the stream and defaults
        assert mux._stream is mock_stream
        assert mux.stream_generation == 0
        assert mux._queues == {}
        assert mux._recv_task is None
        assert mux._queue_max_size == _DEFAULT_QUEUE_MAX_SIZE

    def test_init_custom_queue_size(self):
        # Given a mock stream
        mock_stream = AsyncMock()

        # When a multiplexer is created with a custom queue size
        mux = _StreamMultiplexer(mock_stream, queue_max_size=50)

        # Then it sets the custom queue size
        assert mux._queue_max_size == 50


def _make_response(read_id, data=b"data", range_end=False):
    """Build a BidiReadObjectResponse carrying one data range for *read_id*."""
    return _storage_v2.BidiReadObjectResponse(
        object_data_ranges=[
            _storage_v2.ObjectRangeData(
                checksummed_data=_storage_v2.ChecksummedData(content=data),
                read_range=_storage_v2.ReadRange(
                    read_id=read_id, read_offset=0, read_length=len(data)
                ),
                range_end=range_end,
            )
        ]
    )


def _make_multi_range_response(read_ids, data=b"data"):
    """Build a single response containing one data range per id in *read_ids*."""
    ranges = []
    for rid in read_ids:
        ranges.append(
            _storage_v2.ObjectRangeData(
                checksummed_data=_storage_v2.ChecksummedData(content=data),
                read_range=_storage_v2.ReadRange(
                    read_id=rid, read_offset=0, read_length=len(data)
                ),
            )
        )
    return _storage_v2.BidiReadObjectResponse(object_data_ranges=ranges)


class TestRegisterUnregister:
    """Tests for register()/unregister() queue bookkeeping."""

    def _make_multiplexer(self):
        # recv blocks forever (unset Event) so the loop stays idle if started.
        mock_stream = AsyncMock()
        mock_stream.recv = AsyncMock(side_effect=asyncio.Event().wait)
        return _StreamMultiplexer(mock_stream), mock_stream

    @pytest.mark.asyncio
    async def test_register_returns_bounded_queue(self):
        # Given a multiplexer
        mux, _ = self._make_multiplexer()

        # When registering read IDs
        queue = mux.register({1, 2, 3})

        # Then a bounded queue is returned
        assert isinstance(queue, asyncio.Queue)
        assert queue.maxsize == _DEFAULT_QUEUE_MAX_SIZE

    @pytest.mark.asyncio
    async def test_register_maps_read_ids_to_same_queue(self):
        # Given a multiplexer
        mux, _ = self._make_multiplexer()

        # When registering multiple read IDs
        queue = mux.register({10, 20})

        # Then they map to the same queue
        assert mux._queues[10] is queue
        assert mux._queues[20] is queue

    @pytest.mark.asyncio
    async def test_register_does_not_start_recv_loop(self):
        # Given a multiplexer
        mux, _ = self._make_multiplexer()

        # When registering a read ID
        mux.register({1})

        # Then the receive loop is not started
        assert mux._recv_task is None

    @pytest.mark.asyncio
    async def test_two_registers_get_separate_queues(self):
        # Given a multiplexer
        mux, _ = self._make_multiplexer()

        # When registering different read IDs separately
        q1 = mux.register({1})
        q2 = mux.register({2})

        # Then separate queues are returned
        assert q1 is not q2
        assert mux._queues[1] is q1
        assert mux._queues[2] is q2

    @pytest.mark.asyncio
    async def test_unregister_removes_read_ids(self):
        # Given a multiplexer with registered read IDs
        mux, _ = self._make_multiplexer()
        mux.register({1, 2})

        # When unregistering a read ID
        mux.unregister({1})

        # Then it is removed from the mapping
        assert 1 not in mux._queues
        assert 2 in mux._queues

    @pytest.mark.asyncio
    async def test_unregister_all_does_not_stop_recv_loop(self):
        # Given a multiplexer with an active receive loop
        mux, _ = self._make_multiplexer()
        mux.register({1})
        mux._ensure_recv_loop()
        recv_task = mux._recv_task

        # When unregistering the read ID
        mux.unregister({1})

        # Then the receive loop is not cancelled
        await asyncio.sleep(0)
        assert not recv_task.cancelled()

        # Cleanup: stop the still-running receive loop so no task outlives
        # the test's event loop.
        await mux.close()

    @pytest.mark.asyncio
    async def test_unregister_nonexistent_is_noop(self):
        # Given a multiplexer with a registered read ID
        mux, _ = self._make_multiplexer()
        mux.register({1})

        # When unregistering a non-existent read ID
        mux.unregister({999})

        # Then the existing registration remains
        assert 1 in mux._queues


class TestRecvLoop:
    """Tests for the background receive loop's routing and error handling."""

    @pytest.mark.asyncio
    async def test_routes_response_by_read_id(self):
        # Given a multiplexer with registered queues for read IDs 10 and 20
        mock_stream = AsyncMock()
        resp1 = _make_response(read_id=10, data=b"hello")
        resp2 = _make_response(read_id=20, data=b"world")
        mock_stream.recv = AsyncMock(side_effect=[resp1, resp2, None])

        mux = _StreamMultiplexer(mock_stream)
        q1 = mux.register({10})
        q2 = mux.register({20})

        # When the receive loop is started
        mux._ensure_recv_loop()

        # Then responses are routed to the corresponding queues and stream ends are sent
        item1 = await asyncio.wait_for(q1.get(), timeout=1)
        item2 = await asyncio.wait_for(q2.get(), timeout=1)

        assert item1 is resp1
        assert item2 is resp2

        end1 = await asyncio.wait_for(q1.get(), timeout=1)
        end2 = await asyncio.wait_for(q2.get(), timeout=1)
        assert isinstance(end1, _StreamEnd)
        assert isinstance(end2, _StreamEnd)

    @pytest.mark.asyncio
    async def test_deduplicates_when_multiple_read_ids_map_to_same_queue(self):
        # Given a multiplexer with multiple read IDs mapped to the same queue
        mock_stream = AsyncMock()
        resp = _make_multi_range_response([10, 11])
        mock_stream.recv = AsyncMock(side_effect=[resp, None])

        mux = _StreamMultiplexer(mock_stream)
        queue = mux.register({10, 11})

        # When the receive loop is started
        mux._ensure_recv_loop()

        # Then the response is put into the queue only once
        item = await asyncio.wait_for(queue.get(), timeout=1)
        assert item is resp

        end = await asyncio.wait_for(queue.get(), timeout=1)
        assert isinstance(end, _StreamEnd)

    @pytest.mark.asyncio
    async def test_metadata_only_response_broadcast_to_all(self):
        # Given a multiplexer with multiple registered queues
        mock_stream = AsyncMock()
        metadata_resp = _storage_v2.BidiReadObjectResponse(
            read_handle=_storage_v2.BidiReadHandle(handle=b"handle")
        )
        mock_stream.recv = AsyncMock(side_effect=[metadata_resp, None])

        mux = _StreamMultiplexer(mock_stream)
        q1 = mux.register({10})
        q2 = mux.register({20})

        # When the receive loop is started
        mux._ensure_recv_loop()

        # Then the metadata-only response is broadcast to all queues
        item1 = await asyncio.wait_for(q1.get(), timeout=1)
        item2 = await asyncio.wait_for(q2.get(), timeout=1)
        assert item1 is metadata_resp
        assert item2 is metadata_resp

    @pytest.mark.asyncio
    async def test_stream_end_sends_sentinel_to_all_queues(self):
        # Given a multiplexer with multiple registered queues and a stream that ends immediately
        mock_stream = AsyncMock()
        mock_stream.recv = AsyncMock(return_value=None)

        mux = _StreamMultiplexer(mock_stream)
        q1 = mux.register({10})
        q2 = mux.register({20})

        # When the receive loop is started
        mux._ensure_recv_loop()

        # Then a StreamEnd sentinel is sent to all queues
        end1 = await asyncio.wait_for(q1.get(), timeout=1)
        end2 = await asyncio.wait_for(q2.get(), timeout=1)
        assert isinstance(end1, _StreamEnd)
        assert isinstance(end2, _StreamEnd)

    @pytest.mark.asyncio
    async def test_error_broadcasts_stream_error_to_all_queues(self):
        # Given a multiplexer with multiple registered queues and a stream that raises an error
        mock_stream = AsyncMock()
        exc = RuntimeError("stream broke")
        mock_stream.recv = AsyncMock(side_effect=exc)

        mux = _StreamMultiplexer(mock_stream)
        q1 = mux.register({10})
        q2 = mux.register({20})

        # When the receive loop is started
        mux._ensure_recv_loop()
        await asyncio.sleep(0.05)

        # Then a StreamError is broadcast to all queues
        err1 = q1.get_nowait()
        err2 = q2.get_nowait()
        assert isinstance(err1, _StreamError)
        assert err1.exception is exc
        assert err1.generation == 0
        assert isinstance(err2, _StreamError)
        assert err2.exception is exc

    @pytest.mark.asyncio
    async def test_error_uses_put_nowait(self):
        # Given a multiplexer with a full queue and a stream that raises an error
        mock_stream = AsyncMock()
        exc = RuntimeError("broke")
        mock_stream.recv = AsyncMock(side_effect=exc)

        mux = _StreamMultiplexer(mock_stream, queue_max_size=1)
        queue = mux.register({10})
        queue.put_nowait("filler")

        # When the receive loop is started
        mux._ensure_recv_loop()
        await asyncio.sleep(0.05)

        # Then the error is recorded even if the queue was full
        assert queue.qsize() == 1
        err = queue.get_nowait()
        assert isinstance(err, _StreamError)
        assert err.exception is exc

    @pytest.mark.asyncio
    async def test_unknown_read_id_is_dropped(self):
        # Given a multiplexer and a response with an unknown read ID
        mock_stream = AsyncMock()
        resp = _make_response(read_id=999)
        mock_stream.recv = AsyncMock(side_effect=[resp, None])

        mux = _StreamMultiplexer(mock_stream)
        queue = mux.register({10})

        # When the receive loop is started
        mux._ensure_recv_loop()

        # Then the response is dropped and only StreamEnd is received
        end = await asyncio.wait_for(queue.get(), timeout=1)
        assert isinstance(end, _StreamEnd)


class TestSend:
    """Tests for send() forwarding requests over the underlying stream."""

    @pytest.mark.asyncio
    async def test_send_forwards_to_stream(self):
        # Given a multiplexer and a request
        mock_stream = AsyncMock()
        mock_stream.recv = AsyncMock(side_effect=asyncio.Event().wait)
        mux = _StreamMultiplexer(mock_stream)
        request = _storage_v2.BidiReadObjectRequest(
            read_ranges=[
                _storage_v2.ReadRange(read_id=1, read_offset=0, read_length=10)
            ]
        )

        # When sending the request
        gen = await mux.send(request)

        # Then it is forwarded to the stream and current generation is returned
        mock_stream.send.assert_called_once_with(request)
        assert gen == 0

    @pytest.mark.asyncio
    async def test_send_returns_current_generation(self):
        # Given a multiplexer at generation 5
        mock_stream = AsyncMock()
        mock_stream.recv = AsyncMock(side_effect=asyncio.Event().wait)
        mux = _StreamMultiplexer(mock_stream)
        mux._stream_generation = 5
        request = _storage_v2.BidiReadObjectRequest()

        # When sending a request
        gen = await mux.send(request)

        # Then it returns the current generation
        assert gen == 5

    @pytest.mark.asyncio
    async def test_send_propagates_exception(self):
        # Given a multiplexer where send fails
        mock_stream = AsyncMock()
        mock_stream.send = AsyncMock(side_effect=RuntimeError("send failed"))
        mock_stream.recv = AsyncMock(side_effect=asyncio.Event().wait)
        mux = _StreamMultiplexer(mock_stream)

        # When sending a request
        # Then the exception is propagated
        with pytest.raises(RuntimeError, match="send failed"):
            await mux.send(_storage_v2.BidiReadObjectRequest())


class TestReopenStream:
    """Tests for generation-gated stream reopen semantics."""

    @pytest.mark.asyncio
    async def test_reopen_bumps_generation_and_replaces_stream(self):
        # Given a multiplexer with a registered queue
        old_stream = AsyncMock()
        old_stream.recv = AsyncMock(side_effect=asyncio.Event().wait)
        mux = _StreamMultiplexer(old_stream)
        mux.register({1})

        new_stream = AsyncMock()
        new_stream.recv = AsyncMock(side_effect=asyncio.Event().wait)
        factory = AsyncMock(return_value=new_stream)

        # When the stream is reopened with the correct generation
        await mux.reopen_stream(0, factory)

        # Then the generation is bumped and the stream is replaced
        assert mux.stream_generation == 1
        assert mux._stream is new_stream
        factory.assert_called_once()

        # Cleanup: stop the recv loop started by reopen.
        await mux.close()

    @pytest.mark.asyncio
    async def test_reopen_skips_if_generation_mismatch(self):
        # Given a multiplexer at generation 5
        mock_stream = AsyncMock()
        mock_stream.recv = AsyncMock(side_effect=asyncio.Event().wait)
        mux = _StreamMultiplexer(mock_stream)
        mux._stream_generation = 5
        mux.register({1})

        factory = AsyncMock()

        # When reopen is called with a mismatched generation (3)
        await mux.reopen_stream(3, factory)

        # Then the reopen is skipped and generation remains unchanged
        assert mux.stream_generation == 5
        factory.assert_not_called()

    @pytest.mark.asyncio
    async def test_reopen_broadcasts_error_before_bump(self):
        # Given a multiplexer with a registered queue
        old_stream = AsyncMock()
        old_stream.recv = AsyncMock(side_effect=asyncio.Event().wait)
        mux = _StreamMultiplexer(old_stream)
        queue = mux.register({1})

        new_stream = AsyncMock()
        new_stream.recv = AsyncMock(side_effect=asyncio.Event().wait)
        factory = AsyncMock(return_value=new_stream)

        # When the stream is reopened
        await mux.reopen_stream(0, factory)

        # Then a StreamError is broadcast to the queue before the bump
        err = queue.get_nowait()
        assert isinstance(err, _StreamError)
        assert err.generation == 0

        # Cleanup: stop the recv loop started by reopen.
        await mux.close()

    @pytest.mark.asyncio
    async def test_reopen_starts_new_recv_loop(self):
        # Given a multiplexer with a registered queue (register() does not
        # start the receive loop, so no recv task exists yet)
        old_stream = AsyncMock()
        old_stream.recv = AsyncMock(side_effect=asyncio.Event().wait)
        mux = _StreamMultiplexer(old_stream)
        mux.register({1})
        old_recv_task = mux._recv_task

        new_stream = AsyncMock()
        new_stream.recv = AsyncMock(side_effect=asyncio.Event().wait)
        factory = AsyncMock(return_value=new_stream)

        # When the stream is reopened
        await mux.reopen_stream(0, factory)

        # Then a new receive loop task is started
        assert mux._recv_task is not old_recv_task
        assert not mux._recv_task.done()

        # Cleanup: stop the recv loop started by reopen.
        await mux.close()

    @pytest.mark.asyncio
    async def test_reopen_closes_old_stream_best_effort(self):
        # Given a multiplexer where closing the old stream raises an error
        old_stream = AsyncMock()
        old_stream.recv = AsyncMock(side_effect=asyncio.Event().wait)
        old_stream.close = AsyncMock(side_effect=RuntimeError("close failed"))
        mux = _StreamMultiplexer(old_stream)
        mux.register({1})

        new_stream = AsyncMock()
        new_stream.recv = AsyncMock(side_effect=asyncio.Event().wait)
        factory = AsyncMock(return_value=new_stream)

        # When the stream is reopened
        await mux.reopen_stream(0, factory)

        # Then the reopen still succeeds
        assert mux.stream_generation == 1

        # Cleanup: stop the recv loop started by reopen.
        await mux.close()

    @pytest.mark.asyncio
    async def test_concurrent_reopen_only_one_wins(self):
        # Given a multiplexer and a counting factory
        old_stream = AsyncMock()
        old_stream.recv = AsyncMock(side_effect=asyncio.Event().wait)
        mux = _StreamMultiplexer(old_stream)
        mux.register({1})

        call_count = 0

        async def counting_factory():
            nonlocal call_count
            call_count += 1
            new_stream = AsyncMock()
            new_stream.recv = AsyncMock(side_effect=asyncio.Event().wait)
            return new_stream

        # When concurrent reopen calls are made
        await asyncio.gather(
            mux.reopen_stream(0, counting_factory),
            mux.reopen_stream(0, counting_factory),
        )

        # Then only one factory call is made and generation is bumped once
        assert call_count == 1
        assert mux.stream_generation == 1

        # Cleanup: stop the recv loop started by reopen.
        await mux.close()

    @pytest.mark.asyncio
    async def test_reopen_factory_failure_leaves_generation_unchanged(self):
        """If stream_factory raises, generation is not bumped and recv loop
        is not restarted. The caller's retry manager will re-attempt reopen
        with the same generation, which will succeed because the generation
        check still matches."""
        # Given a multiplexer and a failing factory
        old_stream = AsyncMock()
        old_stream.recv = AsyncMock(side_effect=asyncio.Event().wait)
        mux = _StreamMultiplexer(old_stream)
        mux.register({1})

        failing_factory = AsyncMock(side_effect=RuntimeError("open failed"))

        # When reopen fails
        with pytest.raises(RuntimeError, match="open failed"):
            await mux.reopen_stream(0, failing_factory)

        # Then generation is NOT bumped and recv loop is stopped
        assert mux.stream_generation == 0
        assert mux._recv_task is None or mux._recv_task.done()

        # Given a subsequent successful reopen
        new_stream = AsyncMock()
        new_stream.recv = AsyncMock(side_effect=asyncio.Event().wait)
        ok_factory = AsyncMock(return_value=new_stream)

        # When reopen is called again with the same generation
        await mux.reopen_stream(0, ok_factory)

        # Then it succeeds
        assert mux.stream_generation == 1
        assert mux._stream is new_stream
        assert mux._recv_task is not None and not mux._recv_task.done()

        # Cleanup: stop the recv loop started by the successful reopen.
        await mux.close()


class TestClose:
    """Tests for close() teardown behavior."""

    @pytest.mark.asyncio
    async def test_close_cancels_recv_loop(self):
        # Given a multiplexer with an active receive loop
        mock_stream = AsyncMock()
        mock_stream.recv = AsyncMock(side_effect=asyncio.Event().wait)
        mux = _StreamMultiplexer(mock_stream)
        mux.register({1})
        mux._ensure_recv_loop()
        recv_task = mux._recv_task

        # When closing the multiplexer
        await mux.close()

        # Then the receive loop task is cancelled
        assert recv_task.cancelled()

    @pytest.mark.asyncio
    async def test_close_broadcasts_terminal_error(self):
        # Given a multiplexer with registered queues
        mock_stream = AsyncMock()
        mock_stream.recv = AsyncMock(side_effect=asyncio.Event().wait)
        mux = _StreamMultiplexer(mock_stream)
        q1 = mux.register({1})
        q2 = mux.register({2})

        # When closing the multiplexer
        await mux.close()

        # Then a terminal StreamError is broadcast to all queues
        err1 = q1.get_nowait()
        err2 = q2.get_nowait()
        assert isinstance(err1, _StreamError)
        assert isinstance(err2, _StreamError)

    @pytest.mark.asyncio
    async def test_close_with_no_tasks_is_noop(self):
        # Given a multiplexer with no active tasks
        mock_stream = AsyncMock()
        mux = _StreamMultiplexer(mock_stream)

        # When closing the multiplexer
        # Then it should not raise any error
        await mux.close()  # should not raise