11from collections .abc import Iterable
22from dataclasses import dataclass
3+ from typing import cast
34
45import numpy as np
56import pytest
67
7- from zarr .abc .store import ByteRequest
8+ from zarr .abc .store import ByteGetter , ByteRequest
89from zarr .codecs .sharding import (
910 MAX_UINT_64 ,
1011 ShardingCodec ,
@@ -344,7 +345,7 @@ async def get_partial_values(
344345async def test_load_partial_shard_maybe_index_load_fails () -> None :
345346 """Test _load_partial_shard_maybe returns None when index load fails."""
346347 codec = ShardingCodec (chunk_shape = (8 ,))
347- byte_getter = MockByteGetterWithIndex (index_data = None , chunk_data = None )
348+ byte_getter = cast ( ByteGetter , MockByteGetterWithIndex (index_data = None , chunk_data = None ) )
348349
349350 chunks_per_shard = (2 ,)
350351 all_chunk_coords : set [tuple [int , ...]] = {(0 ,)}
@@ -381,7 +382,7 @@ async def mock_load_index(
381382 monkeypatch .setattr (ShardingCodec , "_load_shard_index_maybe" , mock_load_index )
382383
383384 chunk_data = b"x" * 300
384- byte_getter = MockByteGetter (data = chunk_data )
385+ byte_getter = cast ( ByteGetter , MockByteGetter (data = chunk_data ) )
385386
386387 # Request chunks including the empty one
387388 all_chunk_coords : set [tuple [int , ...]] = {(0 ,), (1 ,), (2 ,)}
@@ -417,7 +418,7 @@ async def mock_load_index(
417418
418419 monkeypatch .setattr (ShardingCodec , "_load_shard_index_maybe" , mock_load_index )
419420
420- byte_getter = MockByteGetter (data = b"" )
421+ byte_getter = cast ( ByteGetter , MockByteGetter (data = b"" ) )
421422
422423 # Request some chunks - all will be empty
423424 all_chunk_coords : set [tuple [int , ...]] = {(0 ,), (1 ,), (2 ,)}
@@ -452,7 +453,8 @@ async def mock_load_index(
452453 monkeypatch .setattr (ShardingCodec , "_load_shard_index_maybe" , mock_load_index )
453454
454455 chunk_data = b"A" * 100 + b"B" * 100
455- byte_getter = MockByteGetter (data = chunk_data )
456+ mock_getter = MockByteGetter (data = chunk_data )
457+ byte_getter = cast (ByteGetter , mock_getter )
456458
457459 all_chunk_coords : set [tuple [int , ...]] = {(0 ,), (1 ,)}
458460
@@ -469,7 +471,7 @@ async def mock_load_index(
469471 assert (1 ,) in result
470472
471473 # get_partial_values should have been called exactly once
472- assert byte_getter .get_partial_values_call_count == 1
474+ assert mock_getter .get_partial_values_call_count == 1
473475
474476
475477async def test_load_partial_shard_single_chunk_read (
@@ -490,7 +492,7 @@ async def mock_load_index(
490492 monkeypatch .setattr (ShardingCodec , "_load_shard_index_maybe" , mock_load_index )
491493
492494 chunk_data = b"\x00 " * 100 + b"E" * 100
493- byte_getter = MockByteGetter (data = chunk_data )
495+ byte_getter = cast ( ByteGetter , MockByteGetter (data = chunk_data ) )
494496
495497 all_chunk_coords : set [tuple [int , ...]] = {(1 ,)}
496498
@@ -527,8 +529,9 @@ async def mock_load_index(
527529
528530 monkeypatch .setattr (ShardingCodec , "_load_shard_index_maybe" , mock_load_index )
529531
530- byte_getter = MockByteGetterWithIndex (
531- index_data = b"" , chunk_data = None , return_none_for_chunks = True
532+ byte_getter = cast (
533+ ByteGetter ,
534+ MockByteGetterWithIndex (index_data = b"" , chunk_data = None , return_none_for_chunks = True ),
532535 )
533536
534537 all_chunk_coords : set [tuple [int , ...]] = {(0 ,)}
0 commit comments