diff --git a/sdk/storage/azure-storage-blob/assets.json b/sdk/storage/azure-storage-blob/assets.json index 6f8a93a4c09b..bae9562aebe0 100644 --- a/sdk/storage/azure-storage-blob/assets.json +++ b/sdk/storage/azure-storage-blob/assets.json @@ -2,5 +2,5 @@ "AssetsRepo": "Azure/azure-sdk-assets", "AssetsRepoPrefixPath": "python", "TagPrefix": "python/storage/azure-storage-blob", - "Tag": "python/storage/azure-storage-blob_e0a670a6a4" + "Tag": "python/storage/azure-storage-blob_89384f00b6" } diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams.py index 27272fdac592..cb745693921c 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams.py @@ -7,7 +7,7 @@ import math import sys from enum import auto, Enum, IntFlag -from io import BytesIO, IOBase, UnsupportedOperation, SEEK_CUR, SEEK_END, SEEK_SET +from io import BytesIO, IOBase, UnsupportedOperation, SEEK_SET from typing import IO, Iterator, Optional from .validation import calculate_crc64 @@ -88,9 +88,6 @@ class StructuredMessageEncodeStream(IOBase): # pylint: disable=too-many-instanc _current_region_length: int _current_region_offset: int - _checksum_offset: int - """Tracks the offset the checksum has been calculated up to for seeking purposes""" - _message_crc64: int _segment_crc64s: dict[int, int] @@ -121,7 +118,6 @@ def __init__( self._current_region_length = self._message_header_length self._current_region_offset = 0 - self._checksum_offset = 0 self._message_crc64 = 0 self._segment_crc64s = {} @@ -171,9 +167,15 @@ def _update_current_region_length(self) -> None: def __len__(self): return self.message_length + @property + def closed(self) -> bool: + return self._inner_stream.closed + def close(self) -> None: - self._inner_stream.close() - super().close() + # Do not close the inner stream or this stream. + # The inner stream is caller-owned and must survive for retries. + # This stream may be re-read after a seek(0) on retry. + pass def readable(self) -> bool: return True @@ -224,66 +226,23 @@ def seek(self, offset: int, whence: int = SEEK_SET) -> int: if not self.seekable(): raise UnsupportedOperation("Inner stream is not seekable.") - if whence == SEEK_SET: - position = offset - elif whence == SEEK_CUR: - position = self.tell() + offset - elif whence == SEEK_END: - position = self.message_length + offset - else: - raise ValueError(f"Invalid value for whence: {whence}") - - if position < 0: - raise ValueError(f"Cannot seek to negative position: {position}") - if position > self.tell(): - raise UnsupportedOperation("This stream only supports seeking backwards.") - - # MESSAGE_HEADER - if position < self._message_header_length: - self._current_region = SMRegion.MESSAGE_HEADER - self._current_region_offset = position - self._content_offset = 0 - self._current_segment_number = 0 - # MESSAGE_FOOTER - elif position >= self.message_length - self._message_footer_length: - self._current_region = SMRegion.MESSAGE_FOOTER - self._current_region_offset = position - (self.message_length - self._message_footer_length) - self._content_offset = self.content_length - self._current_segment_number = self._num_segments - else: - # The size of a "full" segment. Fine to use for calculating new segment number and pos - full_segment_size = self._segment_header_length + self._segment_size + self._segment_footer_length - new_segment_num = 1 + (position - self._message_header_length) // full_segment_size - segment_pos = (position - self._message_header_length) % full_segment_size - previous_segments_total_content_size = (new_segment_num - 1) * self._segment_size - - # We need the size of the segment we are seeking to for some of the calculations below - new_segment_size = self._segment_size - if new_segment_num == self._num_segments: - # The last segment size is the remaining content length - new_segment_size = self.content_length - previous_segments_total_content_size - - # SEGMENT_HEADER - if segment_pos < self._segment_header_length: - self._current_region = SMRegion.SEGMENT_HEADER - self._current_region_offset = segment_pos - self._content_offset = previous_segments_total_content_size - # SEGMENT_CONTENT - elif segment_pos < self._segment_header_length + new_segment_size: - self._current_region = SMRegion.SEGMENT_CONTENT - self._current_region_offset = segment_pos - self._segment_header_length - self._content_offset = previous_segments_total_content_size + self._current_region_offset - # SEGMENT_FOOTER - else: - self._current_region = SMRegion.SEGMENT_FOOTER - self._current_region_offset = segment_pos - self._segment_header_length - new_segment_size - self._content_offset = previous_segments_total_content_size + new_segment_size + if whence != SEEK_SET: + raise UnsupportedOperation("This stream only supports SEEK_SET.") - self._current_segment_number = new_segment_num + if offset != 0: + raise UnsupportedOperation("This stream only supports seeking to position 0.") - self._update_current_region_length() - self._inner_stream.seek((self._initial_content_position or 0) + self._content_offset) - return position + # Reset to initial state + self._content_offset = 0 + self._current_segment_number = 0 + self._current_region = SMRegion.MESSAGE_HEADER + self._current_region_length = self._message_header_length + self._current_region_offset = 0 + self._message_crc64 = 0 + self._segment_crc64s = {} + + self._inner_stream.seek(self._initial_content_position or 0) + return 0 def read(self, size: int = -1) -> bytes: if self.closed: # pylint: disable=using-constant-test @@ -386,14 +345,7 @@ def _read_metadata_region(self, region: SMRegion, size: int, output: BytesIO) -> return read_size def _read_content(self, size: int, output: BytesIO) -> int: - # Will be non-zero if there is data to read that does not need to have checksum calculated. - # Will always be positive as stream can only seek backwards. - checksum_offset = self._checksum_offset - self._content_offset - read_size = min(size, self._current_region_length - self._current_region_offset) - if checksum_offset != 0: - # Only read up to checksum offset this iteration - read_size = min(read_size, checksum_offset) content = self._inner_stream.read(read_size) if len(content) != read_size: @@ -401,16 +353,12 @@ def _read_content(self, size: int, output: BytesIO) -> int: output.write(content) if StructuredMessageProperties.CRC64 in self.flags: - if checksum_offset == 0: - self._segment_crc64s[self._current_segment_number] = calculate_crc64( - content, self._segment_crc64s[self._current_segment_number] - ) - self._message_crc64 = calculate_crc64(content, self._message_crc64) + self._segment_crc64s[self._current_segment_number] = calculate_crc64( + content, self._segment_crc64s[self._current_segment_number] + ) + self._message_crc64 = calculate_crc64(content, self._message_crc64) self._content_offset += read_size - # Only update the checksum offset if we've read new data - if self._content_offset > self._checksum_offset: - self._checksum_offset += read_size self._current_region_offset += read_size if self._current_region_offset == self._current_region_length: self._advance_region(SMRegion.SEGMENT_CONTENT) diff --git a/sdk/storage/azure-storage-blob/tests/test_content_validation.py b/sdk/storage/azure-storage-blob/tests/test_content_validation.py index e2f65b03771e..f4e8fd759221 100644 --- a/sdk/storage/azure-storage-blob/tests/test_content_validation.py +++ b/sdk/storage/azure-storage-blob/tests/test_content_validation.py @@ -567,3 +567,47 @@ def download_hook_fail_once(response): content = downloader.read() assert download_call_count == 2 # Original + retry assert content == data + + @BlobPreparer() + @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy + def test_streaming_with_retry(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + # Setup with retry enabled + token_credential = self.get_credential(BlobServiceClient) + self.bsc = BlobServiceClient( + self.account_url(storage_account_name, "blob"), + token_credential, + retry_total=1, + initial_backoff=0.1, + increment_base=0.1, + logging_enable=True + ) + self.container = self.bsc.get_container_client(self.get_resource_name('utcontainer')) + try: + self.container.create_container() + except ResourceExistsError: + pass + blob = self.container.get_blob_client(self._get_blob_reference()) + + content = b'abc' * 512 + assert_method = assert_structured_message if a == 'crc64' else assert_content_md5 + + call_count = 0 + def hook_fail_once(response): + nonlocal call_count + call_count += 1 + # Assert content validation headers are present on both attempts + assert_method(response) + if call_count == 1: + response.http_response.status_code = 408 # Request Timeout - triggers retry + + # Use stage_block to test structured message streaming + blob.stage_block('1', BytesIO(content), validate_content=a, raw_response_hook=hook_fail_once) + assert call_count == 2 # Original + retry + + blob.commit_block_list([BlobBlock('1')]) + result = blob.download_blob() + assert result.read() == content diff --git a/sdk/storage/azure-storage-blob/tests/test_content_validation_async.py b/sdk/storage/azure-storage-blob/tests/test_content_validation_async.py index 6b1114b81a44..727fa121a9fc 100644 --- a/sdk/storage/azure-storage-blob/tests/test_content_validation_async.py +++ b/sdk/storage/azure-storage-blob/tests/test_content_validation_async.py @@ -542,3 +542,48 @@ def download_hook_fail_once(response): content = await downloader.read() assert download_call_count == 2 # Original + retry assert content == data + + @BlobPreparer() + @pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content + @GenericTestProxyParametrize1() + @recorded_by_proxy_async + async def test_streaming_with_retry(self, a, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + # Setup with retry enabled + token_credential = self.get_credential(BlobServiceClient, is_async=True) + self.bsc = BlobServiceClient( + self.account_url(storage_account_name, "blob"), + token_credential, + retry_total=1, + initial_backoff=0.1, + increment_base=0.1, + logging_enable=True + ) + self.container = self.bsc.get_container_client(self.get_resource_name('utcontainer')) + try: + await self.container.create_container() + except ResourceExistsError: + pass + blob = self.container.get_blob_client(self._get_blob_reference()) + + content = b'abc' * 512 + assert_method = assert_structured_message if a == 'crc64' else assert_content_md5 + + # Test stage_block streaming with retry + call_count = 0 + def hook_fail_once(response): + nonlocal call_count + call_count += 1 + # Assert content validation headers are present on both attempts + assert_method(response) + if call_count == 1: + response.http_response.status_code = 408 # Request Timeout - triggers retry + + # Use stage_block to test structured message streaming + await blob.stage_block('1', BytesIO(content), validate_content=a, raw_response_hook=hook_fail_once) + assert call_count == 2 # Original + retry + + await blob.commit_block_list([BlobBlock('1')]) + result = await blob.download_blob() + assert await result.read() == content diff --git a/sdk/storage/azure-storage-blob/tests/test_streams.py b/sdk/storage/azure-storage-blob/tests/test_streams.py index 874c8e4a912f..e6bbb0d6414c 100644 --- a/sdk/storage/azure-storage-blob/tests/test_streams.py +++ b/sdk/storage/azure-storage-blob/tests/test_streams.py @@ -107,10 +107,12 @@ def test_close(self): assert not stream.closed assert not inner.closed - stream.close() - assert stream.closed - assert inner.closed + stream.close() # no-op + assert not stream.closed + assert not inner.closed + assert stream.read(1) is not None + inner.close() # closing inner will block reads with pytest.raises(ValueError): stream.read(0) @@ -226,28 +228,24 @@ def test_not_seekable(self): with pytest.raises(UnsupportedOperation): sm_stream.seek(0) - def test_seek_whence(self): + def test_seek(self): data = os.urandom(10) inner_stream = BytesIO(data) sm_stream = StructuredMessageEncodeStream(inner_stream, len(data), StructuredMessageProperties.CRC64) - # Read so we can seek backwards sm_stream.read(25) - pos = sm_stream.seek(10, SEEK_SET) - assert pos == 10 - pos = sm_stream.seek(-len(sm_stream) + 9, SEEK_END) - assert pos == 9 - pos = sm_stream.seek(-5, SEEK_CUR) - assert pos == 4 + with pytest.raises(UnsupportedOperation): + sm_stream.seek(5) - def test_seek_forward(self): - data = os.urandom(10) - inner_stream = BytesIO(data) - sm_stream = StructuredMessageEncodeStream(inner_stream, len(data), StructuredMessageProperties.CRC64) + with pytest.raises(UnsupportedOperation): + sm_stream.seek(0, SEEK_CUR) - sm_stream.read(5) with pytest.raises(UnsupportedOperation): - sm_stream.seek(10) + sm_stream.seek(0, SEEK_END) + + # Only SEEK_SET to 0 is supported + pos = sm_stream.seek(0, SEEK_SET) + assert pos == 0 @pytest.mark.parametrize("initial_read, segment_size, flags", [ # Single segment @@ -290,74 +288,6 @@ def test_seek_reverse_beginning(self, initial_read, segment_size, flags): result = sm_stream.read() assert result == expected - @pytest.mark.parametrize("initial_read, seek_offset, segment_size, flags", [ - # Single segment - (10, 5, 2048, StructuredMessageProperties.NONE), # Message header -> Message header - (10, 5, 2048, StructuredMessageProperties.CRC64), - (20, 15, 2048, StructuredMessageProperties.NONE), # Segment header -> Segment header - (20, 15, 2048, StructuredMessageProperties.CRC64), - (100, 50, 2048, StructuredMessageProperties.NONE), # First segment content -> First segment content - (100, 50, 2048, StructuredMessageProperties.CRC64), - (1000, 900, 2048, StructuredMessageProperties.NONE), # Second segment content -> Second segment content - (1000, 900, 2048, StructuredMessageProperties.CRC64), - (530, 525, 2048, StructuredMessageProperties.CRC64), # Segment footer -> Segment footer - (1060, 1050, 2048, StructuredMessageProperties.CRC64), # Message footer -> Segment footer - (1000, 100, 2048, StructuredMessageProperties.NONE), # Second segment content -> First segment content - (1000, 100, 2048, StructuredMessageProperties.CRC64), - (1000, 20, 2048, StructuredMessageProperties.NONE), # Second segment content -> First segment header - (1000, 20, 2048, StructuredMessageProperties.CRC64), - (1000, 530, 2048, StructuredMessageProperties.CRC64), # Second segment content -> First segment footer - (1097, 100, 2048, StructuredMessageProperties.CRC64), # Message footer -> First segment content - # Multiple segments - (10, 5, 500, StructuredMessageProperties.NONE), # Message header -> Message header - (10, 5, 500, StructuredMessageProperties.CRC64), - (20, 15, 500, StructuredMessageProperties.NONE), # Segment header -> Segment header - (20, 15, 500, StructuredMessageProperties.CRC64), - (100, 50, 500, StructuredMessageProperties.NONE), # First segment content -> First segment content - (100, 50, 500, StructuredMessageProperties.CRC64), - (1000, 900, 500, StructuredMessageProperties.NONE), # Second segment content -> Second segment content - (1000, 900, 500, StructuredMessageProperties.CRC64), - (530, 525, 500, StructuredMessageProperties.CRC64), # Segment footer -> Segment footer - (1097, 1090, 500, StructuredMessageProperties.CRC64), # Message footer -> Segment footer - (1000, 100, 500, StructuredMessageProperties.NONE), # Second segment content -> First segment content - (1000, 100, 500, StructuredMessageProperties.CRC64), - (1000, 20, 500, StructuredMessageProperties.NONE), # Second segment content -> First segment header - (1000, 20, 500, StructuredMessageProperties.CRC64), - (1000, 530, 500, StructuredMessageProperties.CRC64), # Second segment content -> First segment footer - (1097, 100, 500, StructuredMessageProperties.CRC64), # Message footer -> First segment content - ]) - def test_seek_reverse_middle(self, initial_read, seek_offset, segment_size, flags): - data = os.urandom(1024) - inner_stream = BytesIO(data) - sm_stream = StructuredMessageEncodeStream(inner_stream, len(data), flags, segment_size=segment_size) - expected = _build_structured_message(data, segment_size, flags)[0].getvalue() - - initial = sm_stream.read(initial_read) - assert initial == expected[:initial_read] - - sm_stream.seek(seek_offset) - result = sm_stream.read() - assert result == expected[seek_offset:] - - @pytest.mark.parametrize("flags", [StructuredMessageProperties.NONE, StructuredMessageProperties.CRC64]) - def test_seek_reverse_random(self, flags): - data = os.urandom(1024) - expected = _build_structured_message(data, 500, flags)[0].getvalue() - - for _ in range(10): - inner_stream = BytesIO(data) - sm_stream = StructuredMessageEncodeStream(inner_stream, len(data), flags, segment_size=500) - - initial_read = random.randint(5, len(data)) - seek_offset = random.randint(0, initial_read) - - initial = sm_stream.read(initial_read) - assert initial == expected[:initial_read] - - sm_stream.seek(seek_offset) - result = sm_stream.read() - assert result == expected[seek_offset:] - @pytest.mark.parametrize("flags", [StructuredMessageProperties.NONE, StructuredMessageProperties.CRC64]) def test_partial_stream_read(self, flags): data = os.urandom(1024) @@ -390,26 +320,7 @@ def test_partial_stream_seek_beginning(self, flags): result = sm_stream.read() assert result == expected - @pytest.mark.parametrize("flags", [StructuredMessageProperties.NONE, StructuredMessageProperties.CRC64]) - def test_partial_stream_seek_middle(self, flags): - data = os.urandom(1024) - partial_read = 100 - - inner_stream = BytesIO(data) - inner_stream.seek(partial_read) - expected = _build_structured_message(data[partial_read:], 500, flags)[0].getvalue() - sm_stream = StructuredMessageEncodeStream(inner_stream, len(data) - partial_read, flags, segment_size=500) - initial = sm_stream.read(501) - assert initial == expected[:501] - - sm_stream.seek(100) - assert inner_stream.tell() == partial_read + (100 - - StructuredMessageConstants.V1_HEADER_LENGTH - - StructuredMessageConstants.V1_SEGMENT_HEADER_LENGTH) - - result = sm_stream.read() - assert result == expected[100:] class TestStructuredMessageDecoder: diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/streams.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/streams.py index 27272fdac592..cb745693921c 100644 --- a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/streams.py +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/streams.py @@ -7,7 +7,7 @@ import math import sys from enum import auto, Enum, IntFlag -from io import BytesIO, IOBase, UnsupportedOperation, SEEK_CUR, SEEK_END, SEEK_SET +from io import BytesIO, IOBase, UnsupportedOperation, SEEK_SET from typing import IO, Iterator, Optional from .validation import calculate_crc64 @@ -88,9 +88,6 @@ class StructuredMessageEncodeStream(IOBase): # pylint: disable=too-many-instanc _current_region_length: int _current_region_offset: int - _checksum_offset: int - """Tracks the offset the checksum has been calculated up to for seeking purposes""" - _message_crc64: int _segment_crc64s: dict[int, int] @@ -121,7 +118,6 @@ def __init__( self._current_region_length = self._message_header_length self._current_region_offset = 0 - self._checksum_offset = 0 self._message_crc64 = 0 self._segment_crc64s = {} @@ -171,9 +167,15 @@ def _update_current_region_length(self) -> None: def __len__(self): return self.message_length + @property + def closed(self) -> bool: + return self._inner_stream.closed + def close(self) -> None: - self._inner_stream.close() - super().close() + # Do not close the inner stream or this stream. + # The inner stream is caller-owned and must survive for retries. + # This stream may be re-read after a seek(0) on retry. + pass def readable(self) -> bool: return True @@ -224,66 +226,23 @@ def seek(self, offset: int, whence: int = SEEK_SET) -> int: if not self.seekable(): raise UnsupportedOperation("Inner stream is not seekable.") - if whence == SEEK_SET: - position = offset - elif whence == SEEK_CUR: - position = self.tell() + offset - elif whence == SEEK_END: - position = self.message_length + offset - else: - raise ValueError(f"Invalid value for whence: {whence}") - - if position < 0: - raise ValueError(f"Cannot seek to negative position: {position}") - if position > self.tell(): - raise UnsupportedOperation("This stream only supports seeking backwards.") - - # MESSAGE_HEADER - if position < self._message_header_length: - self._current_region = SMRegion.MESSAGE_HEADER - self._current_region_offset = position - self._content_offset = 0 - self._current_segment_number = 0 - # MESSAGE_FOOTER - elif position >= self.message_length - self._message_footer_length: - self._current_region = SMRegion.MESSAGE_FOOTER - self._current_region_offset = position - (self.message_length - self._message_footer_length) - self._content_offset = self.content_length - self._current_segment_number = self._num_segments - else: - # The size of a "full" segment. Fine to use for calculating new segment number and pos - full_segment_size = self._segment_header_length + self._segment_size + self._segment_footer_length - new_segment_num = 1 + (position - self._message_header_length) // full_segment_size - segment_pos = (position - self._message_header_length) % full_segment_size - previous_segments_total_content_size = (new_segment_num - 1) * self._segment_size - - # We need the size of the segment we are seeking to for some of the calculations below - new_segment_size = self._segment_size - if new_segment_num == self._num_segments: - # The last segment size is the remaining content length - new_segment_size = self.content_length - previous_segments_total_content_size - - # SEGMENT_HEADER - if segment_pos < self._segment_header_length: - self._current_region = SMRegion.SEGMENT_HEADER - self._current_region_offset = segment_pos - self._content_offset = previous_segments_total_content_size - # SEGMENT_CONTENT - elif segment_pos < self._segment_header_length + new_segment_size: - self._current_region = SMRegion.SEGMENT_CONTENT - self._current_region_offset = segment_pos - self._segment_header_length - self._content_offset = previous_segments_total_content_size + self._current_region_offset - # SEGMENT_FOOTER - else: - self._current_region = SMRegion.SEGMENT_FOOTER - self._current_region_offset = segment_pos - self._segment_header_length - new_segment_size - self._content_offset = previous_segments_total_content_size + new_segment_size + if whence != SEEK_SET: + raise UnsupportedOperation("This stream only supports SEEK_SET.") - self._current_segment_number = new_segment_num + if offset != 0: + raise UnsupportedOperation("This stream only supports seeking to position 0.") - self._update_current_region_length() - self._inner_stream.seek((self._initial_content_position or 0) + self._content_offset) - return position + # Reset to initial state + self._content_offset = 0 + self._current_segment_number = 0 + self._current_region = SMRegion.MESSAGE_HEADER + self._current_region_length = self._message_header_length + self._current_region_offset = 0 + self._message_crc64 = 0 + self._segment_crc64s = {} + + self._inner_stream.seek(self._initial_content_position or 0) + return 0 def read(self, size: int = -1) -> bytes: if self.closed: # pylint: disable=using-constant-test @@ -386,14 +345,7 @@ def _read_metadata_region(self, region: SMRegion, size: int, output: BytesIO) -> return read_size def _read_content(self, size: int, output: BytesIO) -> int: - # Will be non-zero if there is data to read that does not need to have checksum calculated. - # Will always be positive as stream can only seek backwards. - checksum_offset = self._checksum_offset - self._content_offset - read_size = min(size, self._current_region_length - self._current_region_offset) - if checksum_offset != 0: - # Only read up to checksum offset this iteration - read_size = min(read_size, checksum_offset) content = self._inner_stream.read(read_size) if len(content) != read_size: @@ -401,16 +353,12 @@ def _read_content(self, size: int, output: BytesIO) -> int: output.write(content) if StructuredMessageProperties.CRC64 in self.flags: - if checksum_offset == 0: - self._segment_crc64s[self._current_segment_number] = calculate_crc64( - content, self._segment_crc64s[self._current_segment_number] - ) - self._message_crc64 = calculate_crc64(content, self._message_crc64) + self._segment_crc64s[self._current_segment_number] = calculate_crc64( + content, self._segment_crc64s[self._current_segment_number] + ) + self._message_crc64 = calculate_crc64(content, self._message_crc64) self._content_offset += read_size - # Only update the checksum offset if we've read new data - if self._content_offset > self._checksum_offset: - self._checksum_offset += read_size self._current_region_offset += read_size if self._current_region_offset == self._current_region_length: self._advance_region(SMRegion.SEGMENT_CONTENT) diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/streams.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/streams.py index 27272fdac592..cb745693921c 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/streams.py +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/streams.py @@ -7,7 +7,7 @@ import math import sys from enum import auto, Enum, IntFlag -from io import BytesIO, IOBase, UnsupportedOperation, SEEK_CUR, SEEK_END, SEEK_SET +from io import BytesIO, IOBase, UnsupportedOperation, SEEK_SET from typing import IO, Iterator, Optional from .validation import calculate_crc64 @@ -88,9 +88,6 @@ class StructuredMessageEncodeStream(IOBase): # pylint: disable=too-many-instanc _current_region_length: int _current_region_offset: int - _checksum_offset: int - """Tracks the offset the checksum has been calculated up to for seeking purposes""" - _message_crc64: int _segment_crc64s: dict[int, int] @@ -121,7 +118,6 @@ def __init__( self._current_region_length = self._message_header_length self._current_region_offset = 0 - self._checksum_offset = 0 self._message_crc64 = 0 self._segment_crc64s = {} @@ -171,9 +167,15 @@ def _update_current_region_length(self) -> None: def __len__(self): return self.message_length + @property + def closed(self) -> bool: + return self._inner_stream.closed + def close(self) -> None: - self._inner_stream.close() - super().close() + # Do not close the inner stream or this stream. + # The inner stream is caller-owned and must survive for retries. + # This stream may be re-read after a seek(0) on retry. + pass def readable(self) -> bool: return True @@ -224,66 +226,23 @@ def seek(self, offset: int, whence: int = SEEK_SET) -> int: if not self.seekable(): raise UnsupportedOperation("Inner stream is not seekable.") - if whence == SEEK_SET: - position = offset - elif whence == SEEK_CUR: - position = self.tell() + offset - elif whence == SEEK_END: - position = self.message_length + offset - else: - raise ValueError(f"Invalid value for whence: {whence}") - - if position < 0: - raise ValueError(f"Cannot seek to negative position: {position}") - if position > self.tell(): - raise UnsupportedOperation("This stream only supports seeking backwards.") - - # MESSAGE_HEADER - if position < self._message_header_length: - self._current_region = SMRegion.MESSAGE_HEADER - self._current_region_offset = position - self._content_offset = 0 - self._current_segment_number = 0 - # MESSAGE_FOOTER - elif position >= self.message_length - self._message_footer_length: - self._current_region = SMRegion.MESSAGE_FOOTER - self._current_region_offset = position - (self.message_length - self._message_footer_length) - self._content_offset = self.content_length - self._current_segment_number = self._num_segments - else: - # The size of a "full" segment. Fine to use for calculating new segment number and pos - full_segment_size = self._segment_header_length + self._segment_size + self._segment_footer_length - new_segment_num = 1 + (position - self._message_header_length) // full_segment_size - segment_pos = (position - self._message_header_length) % full_segment_size - previous_segments_total_content_size = (new_segment_num - 1) * self._segment_size - - # We need the size of the segment we are seeking to for some of the calculations below - new_segment_size = self._segment_size - if new_segment_num == self._num_segments: - # The last segment size is the remaining content length - new_segment_size = self.content_length - previous_segments_total_content_size - - # SEGMENT_HEADER - if segment_pos < self._segment_header_length: - self._current_region = SMRegion.SEGMENT_HEADER - self._current_region_offset = segment_pos - self._content_offset = previous_segments_total_content_size - # SEGMENT_CONTENT - elif segment_pos < self._segment_header_length + new_segment_size: - self._current_region = SMRegion.SEGMENT_CONTENT - self._current_region_offset = segment_pos - self._segment_header_length - self._content_offset = previous_segments_total_content_size + self._current_region_offset - # SEGMENT_FOOTER - else: - self._current_region = SMRegion.SEGMENT_FOOTER - self._current_region_offset = segment_pos - self._segment_header_length - new_segment_size - self._content_offset = previous_segments_total_content_size + new_segment_size + if whence != SEEK_SET: + raise UnsupportedOperation("This stream only supports SEEK_SET.") - self._current_segment_number = new_segment_num + if offset != 0: + raise UnsupportedOperation("This stream only supports seeking to position 0.") - self._update_current_region_length() - self._inner_stream.seek((self._initial_content_position or 0) + self._content_offset) - return position + # Reset to initial state + self._content_offset = 0 + self._current_segment_number = 0 + self._current_region = SMRegion.MESSAGE_HEADER + self._current_region_length = self._message_header_length + self._current_region_offset = 0 + self._message_crc64 = 0 + self._segment_crc64s = {} + + self._inner_stream.seek(self._initial_content_position or 0) + return 0 def read(self, size: int = -1) -> bytes: if self.closed: # pylint: disable=using-constant-test @@ -386,14 +345,7 @@ def _read_metadata_region(self, region: SMRegion, size: int, output: BytesIO) -> return read_size def _read_content(self, size: int, output: BytesIO) -> int: - # Will be non-zero if there is data to read that does not need to have checksum calculated. - # Will always be positive as stream can only seek backwards. - checksum_offset = self._checksum_offset - self._content_offset - read_size = min(size, self._current_region_length - self._current_region_offset) - if checksum_offset != 0: - # Only read up to checksum offset this iteration - read_size = min(read_size, checksum_offset) content = self._inner_stream.read(read_size) if len(content) != read_size: @@ -401,16 +353,12 @@ def _read_content(self, size: int, output: BytesIO) -> int: output.write(content) if StructuredMessageProperties.CRC64 in self.flags: - if checksum_offset == 0: - self._segment_crc64s[self._current_segment_number] = calculate_crc64( - content, self._segment_crc64s[self._current_segment_number] - ) - self._message_crc64 = calculate_crc64(content, self._message_crc64) + self._segment_crc64s[self._current_segment_number] = calculate_crc64( + content, self._segment_crc64s[self._current_segment_number] + ) + self._message_crc64 = calculate_crc64(content, self._message_crc64) self._content_offset += read_size - # Only update the checksum offset if we've read new data - if self._content_offset > self._checksum_offset: - self._checksum_offset += read_size self._current_region_offset += read_size if self._current_region_offset == self._current_region_length: self._advance_region(SMRegion.SEGMENT_CONTENT) diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/streams.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/streams.py index 27272fdac592..cb745693921c 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/streams.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/streams.py @@ -7,7 +7,7 @@ import math import sys from enum import auto, Enum, IntFlag -from io import BytesIO, IOBase, UnsupportedOperation, SEEK_CUR, SEEK_END, SEEK_SET +from io import BytesIO, IOBase, UnsupportedOperation, SEEK_SET from typing import IO, Iterator, Optional from .validation import calculate_crc64 @@ -88,9 +88,6 @@ class StructuredMessageEncodeStream(IOBase): # pylint: disable=too-many-instanc _current_region_length: int _current_region_offset: int - _checksum_offset: int - """Tracks the offset the checksum has been calculated up to for seeking purposes""" - _message_crc64: int _segment_crc64s: dict[int, int] @@ -121,7 +118,6 @@ def __init__( self._current_region_length = self._message_header_length self._current_region_offset = 0 - self._checksum_offset = 0 self._message_crc64 = 0 self._segment_crc64s = {} @@ -171,9 +167,15 @@ def _update_current_region_length(self) -> None: def __len__(self): return self.message_length + @property + def closed(self) -> bool: + return self._inner_stream.closed + def close(self) -> None: - self._inner_stream.close() - super().close() + # Do not close the inner stream or this stream. + # The inner stream is caller-owned and must survive for retries. + # This stream may be re-read after a seek(0) on retry. + pass def readable(self) -> bool: return True @@ -224,66 +226,23 @@ def seek(self, offset: int, whence: int = SEEK_SET) -> int: if not self.seekable(): raise UnsupportedOperation("Inner stream is not seekable.") - if whence == SEEK_SET: - position = offset - elif whence == SEEK_CUR: - position = self.tell() + offset - elif whence == SEEK_END: - position = self.message_length + offset - else: - raise ValueError(f"Invalid value for whence: {whence}") - - if position < 0: - raise ValueError(f"Cannot seek to negative position: {position}") - if position > self.tell(): - raise UnsupportedOperation("This stream only supports seeking backwards.") - - # MESSAGE_HEADER - if position < self._message_header_length: - self._current_region = SMRegion.MESSAGE_HEADER - self._current_region_offset = position - self._content_offset = 0 - self._current_segment_number = 0 - # MESSAGE_FOOTER - elif position >= self.message_length - self._message_footer_length: - self._current_region = SMRegion.MESSAGE_FOOTER - self._current_region_offset = position - (self.message_length - self._message_footer_length) - self._content_offset = self.content_length - self._current_segment_number = self._num_segments - else: - # The size of a "full" segment. Fine to use for calculating new segment number and pos - full_segment_size = self._segment_header_length + self._segment_size + self._segment_footer_length - new_segment_num = 1 + (position - self._message_header_length) // full_segment_size - segment_pos = (position - self._message_header_length) % full_segment_size - previous_segments_total_content_size = (new_segment_num - 1) * self._segment_size - - # We need the size of the segment we are seeking to for some of the calculations below - new_segment_size = self._segment_size - if new_segment_num == self._num_segments: - # The last segment size is the remaining content length - new_segment_size = self.content_length - previous_segments_total_content_size - - # SEGMENT_HEADER - if segment_pos < self._segment_header_length: - self._current_region = SMRegion.SEGMENT_HEADER - self._current_region_offset = segment_pos - self._content_offset = previous_segments_total_content_size - # SEGMENT_CONTENT - elif segment_pos < self._segment_header_length + new_segment_size: - self._current_region = SMRegion.SEGMENT_CONTENT - self._current_region_offset = segment_pos - self._segment_header_length - self._content_offset = previous_segments_total_content_size + self._current_region_offset - # SEGMENT_FOOTER - else: - self._current_region = SMRegion.SEGMENT_FOOTER - self._current_region_offset = segment_pos - self._segment_header_length - new_segment_size - self._content_offset = previous_segments_total_content_size + new_segment_size + if whence != SEEK_SET: + raise UnsupportedOperation("This stream only supports SEEK_SET.") - self._current_segment_number = new_segment_num + if offset != 0: + raise UnsupportedOperation("This stream only supports seeking to position 0.") - self._update_current_region_length() - self._inner_stream.seek((self._initial_content_position or 0) + self._content_offset) - return position + # Reset to initial state + self._content_offset = 0 + self._current_segment_number = 0 + self._current_region = SMRegion.MESSAGE_HEADER + self._current_region_length = self._message_header_length + self._current_region_offset = 0 + self._message_crc64 = 0 + self._segment_crc64s = {} + + self._inner_stream.seek(self._initial_content_position or 0) + return 0 def read(self, size: int = -1) -> bytes: if self.closed: # pylint: disable=using-constant-test @@ -386,14 +345,7 @@ def _read_metadata_region(self, region: SMRegion, size: int, output: BytesIO) -> return read_size def _read_content(self, size: int, output: BytesIO) -> int: - # Will be non-zero if there is data to read that does not need to have checksum calculated. - # Will always be positive as stream can only seek backwards. - checksum_offset = self._checksum_offset - self._content_offset - read_size = min(size, self._current_region_length - self._current_region_offset) - if checksum_offset != 0: - # Only read up to checksum offset this iteration - read_size = min(read_size, checksum_offset) content = self._inner_stream.read(read_size) if len(content) != read_size: @@ -401,16 +353,12 @@ def _read_content(self, size: int, output: BytesIO) -> int: output.write(content) if StructuredMessageProperties.CRC64 in self.flags: - if checksum_offset == 0: - self._segment_crc64s[self._current_segment_number] = calculate_crc64( - content, self._segment_crc64s[self._current_segment_number] - ) - self._message_crc64 = calculate_crc64(content, self._message_crc64) + self._segment_crc64s[self._current_segment_number] = calculate_crc64( + content, self._segment_crc64s[self._current_segment_number] + ) + self._message_crc64 = calculate_crc64(content, self._message_crc64) self._content_offset += read_size - # Only update the checksum offset if we've read new data - if self._content_offset > self._checksum_offset: - self._checksum_offset += read_size self._current_region_offset += read_size if self._current_region_offset == self._current_region_length: self._advance_region(SMRegion.SEGMENT_CONTENT) diff --git a/sdk/storage/ci.yml b/sdk/storage/ci.yml index dd044ebb3515..5abd3aaf8520 100644 --- a/sdk/storage/ci.yml +++ b/sdk/storage/ci.yml @@ -54,6 +54,11 @@ extends: safeName: azurestoragequeue - name: azure-storage-extensions safeName: azurestorageextensions + triggeringPaths: + - /sdk/storage/azure-storage-blob + - /sdk/storage/azure-storage-file-datalake + - /sdk/storage/azure-storage-file-share + - /sdk/storage/azure-storage-queue # Pure C-based storage extension package, not generating docs at this moment. skipPublishDocGithubIo: true skipPublishDocMs: true