Skip to content

Commit f8b3d38

Browse files
authored
fix/nested shard reads (#3655)
* fix partial nested shard reads * changelog
1 parent d45f846 commit f8b3d38

File tree

3 files changed

+33
-28
lines changed

3 files changed

+33
-28
lines changed

changes/3655.bugfix.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fixed a bug in the sharding codec that prevented nested shard reads in certain cases.

src/zarr/codecs/sharding.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
)
5353
from zarr.core.metadata.v3 import parse_codecs
5454
from zarr.registry import get_ndbuffer_class, get_pipeline_class
55+
from zarr.storage._utils import _normalize_byte_range_index
5556

5657
if TYPE_CHECKING:
5758
from collections.abc import Iterator
@@ -86,11 +87,16 @@ class _ShardingByteGetter(ByteGetter):
8687
async def get(
8788
self, prototype: BufferPrototype, byte_range: ByteRequest | None = None
8889
) -> Buffer | None:
89-
assert byte_range is None, "byte_range is not supported within shards"
9090
assert prototype == default_buffer_prototype(), (
9191
f"prototype is not supported within shards currently. diff: {prototype} != {default_buffer_prototype()}"
9292
)
93-
return self.shard_dict.get(self.chunk_coords)
93+
value = self.shard_dict.get(self.chunk_coords)
94+
if value is None:
95+
return None
96+
if byte_range is None:
97+
return value
98+
start, stop = _normalize_byte_range_index(value, byte_range)
99+
return value[start:stop]
94100

95101

96102
@dataclass(frozen=True)
@@ -597,7 +603,8 @@ async def _decode_shard_index(
597603
)
598604
)
599605
)
600-
assert index_array is not None
606+
# This cannot be None because we have the bytes already
607+
index_array = cast(NDBuffer, index_array)
601608
return _ShardIndex(index_array.as_numpy_array())
602609

603610
async def _encode_shard_index(self, index: _ShardIndex) -> Buffer:

tests/test_codecs/test_sharding.py

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
TransposeCodec,
1919
)
2020
from zarr.core.buffer import NDArrayLike, default_buffer_prototype
21-
from zarr.errors import ZarrUserWarning
2221
from zarr.storage import StorePath, ZipStore
2322

2423
from ..conftest import ArrayRequest
@@ -239,12 +238,14 @@ def test_sharding_partial_overwrite(
239238
assert np.array_equal(data, read_data)
240239

241240

241+
# Zip storage raises a warning about a duplicate name, which we ignore.
242+
@pytest.mark.filterwarnings("ignore:Duplicate name.*:UserWarning")
242243
@pytest.mark.parametrize(
243244
"array_fixture",
244245
[
245-
ArrayRequest(shape=(128,) * 3, dtype="uint16", order="F"),
246+
ArrayRequest(shape=(127, 128, 129), dtype="uint16", order="F"),
246247
],
247-
indirect=["array_fixture"],
248+
indirect=True,
248249
)
249250
@pytest.mark.parametrize(
250251
"outer_index_location",
@@ -263,24 +264,23 @@ def test_nested_sharding(
263264
) -> None:
264265
data = array_fixture
265266
spath = StorePath(store)
266-
msg = "Combining a `sharding_indexed` codec disables partial reads and writes, which may lead to inefficient performance."
267-
with pytest.warns(ZarrUserWarning, match=msg):
268-
a = zarr.create_array(
269-
spath,
270-
shape=data.shape,
271-
chunks=(64, 64, 64),
272-
dtype=data.dtype,
273-
fill_value=0,
274-
serializer=ShardingCodec(
275-
chunk_shape=(32, 32, 32),
276-
codecs=[
277-
ShardingCodec(chunk_shape=(16, 16, 16), index_location=inner_index_location)
278-
],
279-
index_location=outer_index_location,
280-
),
281-
)
267+
# compressors=None ensures no BytesBytesCodec is added, which keeps
268+
# supports_partial_decode=True and exercises the partial decode path
269+
a = zarr.create_array(
270+
spath,
271+
data=data,
272+
chunks=(64,) * data.ndim,
273+
compressors=None,
274+
serializer=ShardingCodec(
275+
chunk_shape=(32,) * data.ndim,
276+
codecs=[
277+
ShardingCodec(chunk_shape=(16,) * data.ndim, index_location=inner_index_location)
278+
],
279+
index_location=outer_index_location,
280+
),
281+
)
282282

283-
a[:, :, :] = data
283+
a[:] = data
284284

285285
read_data = a[0 : data.shape[0], 0 : data.shape[1], 0 : data.shape[2]]
286286
assert isinstance(read_data, NDArrayLike)
@@ -326,13 +326,10 @@ def test_nested_sharding_create_array(
326326
filters=None,
327327
compressors=None,
328328
)
329-
print(a.metadata.to_dict())
330329

331-
a[:, :, :] = data
330+
a[:] = data
332331

333-
read_data = a[0 : data.shape[0], 0 : data.shape[1], 0 : data.shape[2]]
334-
assert isinstance(read_data, NDArrayLike)
335-
assert data.shape == read_data.shape
332+
read_data = a[:]
336333
assert np.array_equal(data, read_data)
337334

338335

0 commit comments

Comments
 (0)