Skip to content

Commit 80e6efe

Browse files
committed
tests: work around mypy not seeing equivalent protocols as equivalent
1 parent 36d3909 commit 80e6efe

1 file changed

Lines changed: 12 additions & 9 deletions

File tree

tests/test_codecs/test_sharding_unit.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from collections.abc import Iterable
22
from dataclasses import dataclass
3+
from typing import cast
34

45
import numpy as np
56
import pytest
67

7-
from zarr.abc.store import ByteRequest
8+
from zarr.abc.store import ByteGetter, ByteRequest
89
from zarr.codecs.sharding import (
910
MAX_UINT_64,
1011
ShardingCodec,
@@ -344,7 +345,7 @@ async def get_partial_values(
344345
async def test_load_partial_shard_maybe_index_load_fails() -> None:
345346
"""Test _load_partial_shard_maybe returns None when index load fails."""
346347
codec = ShardingCodec(chunk_shape=(8,))
347-
byte_getter = MockByteGetterWithIndex(index_data=None, chunk_data=None)
348+
byte_getter = cast(ByteGetter, MockByteGetterWithIndex(index_data=None, chunk_data=None))
348349

349350
chunks_per_shard = (2,)
350351
all_chunk_coords: set[tuple[int, ...]] = {(0,)}
@@ -381,7 +382,7 @@ async def mock_load_index(
381382
monkeypatch.setattr(ShardingCodec, "_load_shard_index_maybe", mock_load_index)
382383

383384
chunk_data = b"x" * 300
384-
byte_getter = MockByteGetter(data=chunk_data)
385+
byte_getter = cast(ByteGetter, MockByteGetter(data=chunk_data))
385386

386387
# Request chunks including the empty one
387388
all_chunk_coords: set[tuple[int, ...]] = {(0,), (1,), (2,)}
@@ -417,7 +418,7 @@ async def mock_load_index(
417418

418419
monkeypatch.setattr(ShardingCodec, "_load_shard_index_maybe", mock_load_index)
419420

420-
byte_getter = MockByteGetter(data=b"")
421+
byte_getter = cast(ByteGetter, MockByteGetter(data=b""))
421422

422423
# Request some chunks - all will be empty
423424
all_chunk_coords: set[tuple[int, ...]] = {(0,), (1,), (2,)}
@@ -452,7 +453,8 @@ async def mock_load_index(
452453
monkeypatch.setattr(ShardingCodec, "_load_shard_index_maybe", mock_load_index)
453454

454455
chunk_data = b"A" * 100 + b"B" * 100
455-
byte_getter = MockByteGetter(data=chunk_data)
456+
mock_getter = MockByteGetter(data=chunk_data)
457+
byte_getter = cast(ByteGetter, mock_getter)
456458

457459
all_chunk_coords: set[tuple[int, ...]] = {(0,), (1,)}
458460

@@ -469,7 +471,7 @@ async def mock_load_index(
469471
assert (1,) in result
470472

471473
# get_partial_values should have been called exactly once
472-
assert byte_getter.get_partial_values_call_count == 1
474+
assert mock_getter.get_partial_values_call_count == 1
473475

474476

475477
async def test_load_partial_shard_single_chunk_read(
@@ -490,7 +492,7 @@ async def mock_load_index(
490492
monkeypatch.setattr(ShardingCodec, "_load_shard_index_maybe", mock_load_index)
491493

492494
chunk_data = b"\x00" * 100 + b"E" * 100
493-
byte_getter = MockByteGetter(data=chunk_data)
495+
byte_getter = cast(ByteGetter, MockByteGetter(data=chunk_data))
494496

495497
all_chunk_coords: set[tuple[int, ...]] = {(1,)}
496498

@@ -527,8 +529,9 @@ async def mock_load_index(
527529

528530
monkeypatch.setattr(ShardingCodec, "_load_shard_index_maybe", mock_load_index)
529531

530-
byte_getter = MockByteGetterWithIndex(
531-
index_data=b"", chunk_data=None, return_none_for_chunks=True
532+
byte_getter = cast(
533+
ByteGetter,
534+
MockByteGetterWithIndex(index_data=b"", chunk_data=None, return_none_for_chunks=True),
532535
)
533536

534537
all_chunk_coords: set[tuple[int, ...]] = {(0,)}

0 commit comments

Comments
 (0)