Skip to content

Commit 47a0e23

Browse files
author
Cipher
committed
fix: apply drop_axes squeeze in partial decode path for sharding (#3691)
When reading sharded arrays with mixed integer/list indexing (e.g. arr[0:10, 0, [0, 1]]), the outer OrthogonalIndexer produces chunk selections that have been ix_()-transformed for orthogonal advanced indexing. Integer indices become single-element ranges (size-1 dims) via ix_() to enable NumPy orthogonal indexing. In CodecPipeline.read_batch(), the non-partial path correctly applies drop_axes.squeeze() to remove those size-1 integer dimensions before writing to the output buffer. However, the partial decode path (used by ShardingCodec) was missing this squeeze step. The ShardingCodec._decode_partial_single() receives the ix_-transformed chunk selection, interprets it as coordinate (fancy) indexing, and returns an array with shape (10, 1, 2) instead of (10, 2). The subsequent assignment to out[out_selection] then fails with: ValueError: could not broadcast input array from shape (10,1,2) into shape (10,2) Fix: apply drop_axes squeeze to chunk_array in the partial decode branch of read_batch(), matching the behaviour of the non-partial path. Fixes #3691
1 parent 1cb1cce commit 47a0e23

File tree

2 files changed

+53
-0
lines changed

2 files changed

+53
-0
lines changed

src/zarr/core/codec_pipeline.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,8 @@ async def read_batch(
263263
chunk_array_batch, batch_info, strict=False
264264
):
265265
if chunk_array is not None:
266+
if drop_axes != ():
267+
chunk_array = chunk_array.squeeze(axis=drop_axes)
266268
out[out_selection] = chunk_array
267269
else:
268270
out[out_selection] = fill_value_or_default(chunk_spec)

tests/test_codecs/test_sharding.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,3 +501,54 @@ def test_invalid_shard_shape() -> None:
501501
dtype=np.dtype("uint8"),
502502
fill_value=0,
503503
)
504+
505+
506+
@pytest.mark.parametrize("store", ["local"], indirect=["store"])
507+
def test_sharding_mixed_integer_list_indexing(store: Store) -> None:
508+
"""Regression test for https://github.com/zarr-developers/zarr-python/issues/3691.
509+
510+
Mixed integer/list indexing on sharded arrays should return the same
511+
shape and data as on equivalent chunked arrays.
512+
"""
513+
import numpy as np
514+
515+
data = np.arange(200 * 100 * 10, dtype=np.uint8).reshape(200, 100, 10)
516+
517+
chunked = zarr.create_array(
518+
store,
519+
name="chunked",
520+
shape=(200, 100, 10),
521+
dtype=np.uint8,
522+
chunks=(200, 100, 1),
523+
overwrite=True,
524+
)
525+
chunked[:, :, :] = data
526+
527+
sharded = zarr.create_array(
528+
store,
529+
name="sharded",
530+
shape=(200, 100, 10),
531+
dtype=np.uint8,
532+
chunks=(200, 100, 1),
533+
shards=(200, 100, 10),
534+
overwrite=True,
535+
)
536+
sharded[:, :, :] = data
537+
538+
# Mixed integer + list indexing
539+
c = chunked[0:10, 0, [0, 1]]
540+
s = sharded[0:10, 0, [0, 1]]
541+
assert c.shape == s.shape == (10, 2), f"Expected (10, 2), got chunked={c.shape}, sharded={s.shape}"
542+
np.testing.assert_array_equal(c, s)
543+
544+
# Multiple integer axes
545+
c2 = chunked[0, 0, [0, 1, 2]]
546+
s2 = sharded[0, 0, [0, 1, 2]]
547+
assert c2.shape == s2.shape == (3,)
548+
np.testing.assert_array_equal(c2, s2)
549+
550+
# Slice + integer + slice
551+
c3 = chunked[0:5, 1, 0:3]
552+
s3 = sharded[0:5, 1, 0:3]
553+
assert c3.shape == s3.shape == (5, 3)
554+
np.testing.assert_array_equal(c3, s3)

0 commit comments

Comments
 (0)