2 changes: 1 addition & 1 deletion sdk/storage/azure-storage-blob/assets.json
@@ -2,5 +2,5 @@
"AssetsRepo": "Azure/azure-sdk-assets",
"AssetsRepoPrefixPath": "python",
"TagPrefix": "python/storage/azure-storage-blob",
"Tag": "python/storage/azure-storage-blob_e0a670a6a4"
"Tag": "python/storage/azure-storage-blob_89384f00b6"
}
106 changes: 27 additions & 79 deletions sdk/storage/azure-storage-blob/azure/storage/blob/_shared/streams.py
@@ -88,9 +88,6 @@ class StructuredMessageEncodeStream(IOBase): # pylint: disable=too-many-instanc
_current_region_length: int
_current_region_offset: int

_checksum_offset: int
"""Tracks the offset the checksum has been calculated up to for seeking purposes"""

_message_crc64: int
_segment_crc64s: dict[int, int]

@@ -121,7 +118,6 @@ def __init__(
self._current_region_length = self._message_header_length
self._current_region_offset = 0

self._checksum_offset = 0
self._message_crc64 = 0
self._segment_crc64s = {}

@@ -171,9 +167,15 @@ def _update_current_region_length(self) -> None:
def __len__(self):
return self.message_length

@property
def closed(self) -> bool:
return self._inner_stream.closed

def close(self) -> None:
self._inner_stream.close()
super().close()
# Do not close the inner stream or this stream.
# The inner stream is caller-owned and must survive for retries.
# This stream may be re-read after a seek(0) on retry.
pass

def readable(self) -> bool:
return True
@@ -224,66 +226,23 @@ def seek(self, offset: int, whence: int = SEEK_SET) -> int:
if not self.seekable():
raise UnsupportedOperation("Inner stream is not seekable.")

if whence == SEEK_SET:
position = offset
elif whence == SEEK_CUR:
position = self.tell() + offset
elif whence == SEEK_END:
position = self.message_length + offset
else:
raise ValueError(f"Invalid value for whence: {whence}")

if position < 0:
raise ValueError(f"Cannot seek to negative position: {position}")
if position > self.tell():
raise UnsupportedOperation("This stream only supports seeking backwards.")

# MESSAGE_HEADER
if position < self._message_header_length:
self._current_region = SMRegion.MESSAGE_HEADER
self._current_region_offset = position
self._content_offset = 0
self._current_segment_number = 0
# MESSAGE_FOOTER
elif position >= self.message_length - self._message_footer_length:
self._current_region = SMRegion.MESSAGE_FOOTER
self._current_region_offset = position - (self.message_length - self._message_footer_length)
self._content_offset = self.content_length
self._current_segment_number = self._num_segments
else:
# The size of a "full" segment. Fine to use for calculating new segment number and pos
full_segment_size = self._segment_header_length + self._segment_size + self._segment_footer_length
new_segment_num = 1 + (position - self._message_header_length) // full_segment_size
segment_pos = (position - self._message_header_length) % full_segment_size
previous_segments_total_content_size = (new_segment_num - 1) * self._segment_size

# We need the size of the segment we are seeking to for some of the calculations below
new_segment_size = self._segment_size
if new_segment_num == self._num_segments:
# The last segment size is the remaining content length
new_segment_size = self.content_length - previous_segments_total_content_size

# SEGMENT_HEADER
if segment_pos < self._segment_header_length:
self._current_region = SMRegion.SEGMENT_HEADER
self._current_region_offset = segment_pos
self._content_offset = previous_segments_total_content_size
# SEGMENT_CONTENT
elif segment_pos < self._segment_header_length + new_segment_size:
self._current_region = SMRegion.SEGMENT_CONTENT
self._current_region_offset = segment_pos - self._segment_header_length
self._content_offset = previous_segments_total_content_size + self._current_region_offset
# SEGMENT_FOOTER
else:
self._current_region = SMRegion.SEGMENT_FOOTER
self._current_region_offset = segment_pos - self._segment_header_length - new_segment_size
self._content_offset = previous_segments_total_content_size + new_segment_size
if whence != SEEK_SET:
raise UnsupportedOperation("This stream only supports SEEK_SET.")

self._current_segment_number = new_segment_num
if offset != 0:
raise UnsupportedOperation("This stream only supports seeking to position 0.")

self._update_current_region_length()
self._inner_stream.seek((self._initial_content_position or 0) + self._content_offset)
return position
# Reset to initial state
self._content_offset = 0
self._current_segment_number = 0
self._current_region = SMRegion.MESSAGE_HEADER
self._current_region_length = self._message_header_length
self._current_region_offset = 0
self._message_crc64 = 0
self._segment_crc64s = {}

self._inner_stream.seek(self._initial_content_position or 0)
return 0

def read(self, size: int = -1) -> bytes:
if self.closed: # pylint: disable=using-constant-test
@@ -386,31 +345,20 @@ def _read_metadata_region(self, region: SMRegion, size: int, output: BytesIO) ->
return read_size

def _read_content(self, size: int, output: BytesIO) -> int:
# Will be non-zero if there is data to read that does not need to have checksum calculated.
# Will always be positive as stream can only seek backwards.
checksum_offset = self._checksum_offset - self._content_offset

read_size = min(size, self._current_region_length - self._current_region_offset)
if checksum_offset != 0:
# Only read up to checksum offset this iteration
read_size = min(read_size, checksum_offset)

content = self._inner_stream.read(read_size)
if len(content) != read_size:
raise ValueError("Content ended early when encoding structured message.")
output.write(content)

if StructuredMessageProperties.CRC64 in self.flags:
if checksum_offset == 0:
self._segment_crc64s[self._current_segment_number] = calculate_crc64(
content, self._segment_crc64s[self._current_segment_number]
)
self._message_crc64 = calculate_crc64(content, self._message_crc64)
self._segment_crc64s[self._current_segment_number] = calculate_crc64(
content, self._segment_crc64s[self._current_segment_number]
)
self._message_crc64 = calculate_crc64(content, self._message_crc64)

self._content_offset += read_size
# Only update the checksum offset if we've read new data
if self._content_offset > self._checksum_offset:
self._checksum_offset += read_size
self._current_region_offset += read_size
if self._current_region_offset == self._current_region_length:
self._advance_region(SMRegion.SEGMENT_CONTENT)
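
The no-op close() and the 0-only seek() above are what let the pipeline's retry policy replay the request body: tearing down a failed attempt can no longer close the caller-owned inner stream, and seek(0) rewinds both streams and clears the CRC64 accumulators so a second read produces an identical, freshly checksummed message. A minimal sketch of that replay loop follows; the send callable, the TimeoutError handling, and the attempt count are illustrative placeholders, not the SDK's actual retry policy.

import io

def send_with_retry(stream: io.IOBase, send, attempts: int = 2):
    # Replay a rewindable request body across retry attempts.
    # `send` stands in for whatever transmits the bytes and raises on a
    # transient failure; it is not an SDK API.
    last_error = None
    for _ in range(attempts):
        try:
            return send(stream.read())
        except TimeoutError as error:  # placeholder for a retryable error
            last_error = error
            # seek(0) resets the encode stream's region bookkeeping and its
            # CRC64 accumulators, so the re-read yields a byte-identical,
            # freshly checksummed structured message.
            stream.seek(0)
    raise last_error
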
44 changes: 44 additions & 0 deletions sdk/storage/azure-storage-blob/tests/test_content_validation.py
@@ -567,3 +567,47 @@ def download_hook_fail_once(response):
content = downloader.read()
assert download_call_count == 2 # Original + retry
assert content == data

@BlobPreparer()
@pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content
@GenericTestProxyParametrize1()
@recorded_by_proxy
def test_streaming_with_retry(self, a, **kwargs):
storage_account_name = kwargs.pop("storage_account_name")

# Setup with retry enabled
token_credential = self.get_credential(BlobServiceClient)
self.bsc = BlobServiceClient(
self.account_url(storage_account_name, "blob"),
token_credential,
retry_total=1,
initial_backoff=0.1,
increment_base=0.1,
logging_enable=True
)
self.container = self.bsc.get_container_client(self.get_resource_name('utcontainer'))
try:
self.container.create_container()
except ResourceExistsError:
pass
blob = self.container.get_blob_client(self._get_blob_reference())

content = b'abc' * 512
assert_method = assert_structured_message if a == 'crc64' else assert_content_md5

call_count = 0
def hook_fail_once(response):
nonlocal call_count
call_count += 1
# Assert content validation headers are present on both attempts
assert_method(response)
if call_count == 1:
response.http_response.status_code = 408 # Request Timeout - triggers retry

# Use stage_block to test structured message streaming
blob.stage_block('1', BytesIO(content), validate_content=a, raw_response_hook=hook_fail_once)
assert call_count == 2 # Original + retry

blob.commit_block_list([BlobBlock('1')])
result = blob.download_blob()
assert result.read() == content
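
The inline hook_fail_once above counts calls with nonlocal and rewrites the first response's status code to 408 so the client retries once. The same idea as a reusable factory is sketched below; it is illustrative only and not part of this change, and assert_method stands for the helpers already used in this file (assert_structured_message or assert_content_md5).

def make_fail_once_hook(assert_method, failure_status=408):
    # Illustrative only: build a raw_response_hook that fails the first attempt.
    state = {"calls": 0}

    def hook(response):
        state["calls"] += 1
        # Validation headers must be present on the original attempt and the retry.
        assert_method(response)
        if state["calls"] == 1:
            # Rewriting the status makes the retry policy resend the request.
            response.http_response.status_code = failure_status

    hook.state = state  # expose the counter so the test can assert on it
    return hook

A test would then pass make_fail_once_hook(assert_method) as raw_response_hook and assert hook.state["calls"] == 2 afterwards.
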
(async counterpart of test_content_validation.py)
@@ -542,3 +542,48 @@ def download_hook_fail_once(response):
content = await downloader.read()
assert download_call_count == 2 # Original + retry
assert content == data

@BlobPreparer()
@pytest.mark.parametrize('a', [True, 'md5', 'crc64']) # a: validate_content
@GenericTestProxyParametrize1()
@recorded_by_proxy_async
async def test_streaming_with_retry(self, a, **kwargs):
storage_account_name = kwargs.pop("storage_account_name")

# Setup with retry enabled
token_credential = self.get_credential(BlobServiceClient, is_async=True)
self.bsc = BlobServiceClient(
self.account_url(storage_account_name, "blob"),
token_credential,
retry_total=1,
initial_backoff=0.1,
increment_base=0.1,
logging_enable=True
)
self.container = self.bsc.get_container_client(self.get_resource_name('utcontainer'))
try:
await self.container.create_container()
except ResourceExistsError:
pass
blob = self.container.get_blob_client(self._get_blob_reference())

content = b'abc' * 512
assert_method = assert_structured_message if a == 'crc64' else assert_content_md5

# Test stage_block streaming with retry
call_count = 0
def hook_fail_once(response):
nonlocal call_count
call_count += 1
# Assert content validation headers are present on both attempts
assert_method(response)
if call_count == 1:
response.http_response.status_code = 408 # Request Timeout - triggers retry

# Use stage_block to test structured message streaming
await blob.stage_block('1', BytesIO(content), validate_content=a, raw_response_hook=hook_fail_once)
assert call_count == 2 # Original + retry

await blob.commit_block_list([BlobBlock('1')])
result = await blob.download_blob()
assert await result.read() == content