Skip to content

Commit 332ddc2

Browse files
authored
Merge branch 'main' into fill-missing-chunks
2 parents df73f2a + 1bfa53f commit 332ddc2

4 files changed

Lines changed: 60 additions & 3 deletions

File tree

docs/_static/favicon-96x96.png

12.4 KB
Loading

mkdocs.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ theme:
8484
name: material
8585
custom_dir: docs/overrides
8686
logo: _static/logo_bw.png
87+
favicon: _static/favicon-96x96.png
8788

8889
palette:
8990
# Light mode

src/zarr/core/codec_pipeline.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,8 @@ async def read_batch(
265265
chunk_array_batch, batch_info, strict=False
266266
):
267267
if chunk_array is not None:
268+
if drop_axes:
269+
chunk_array = chunk_array.squeeze(axis=drop_axes)
268270
out[out_selection] = chunk_array
269271
elif chunk_spec.config.fill_missing_chunks:
270272
out[out_selection] = fill_value_or_default(chunk_spec)
@@ -297,7 +299,7 @@ async def read_batch(
297299
) in zip(chunk_array_batch, batch_info, strict=False):
298300
if chunk_array is not None:
299301
tmp = chunk_array[chunk_selection]
300-
if drop_axes != ():
302+
if drop_axes:
301303
tmp = tmp.squeeze(axis=drop_axes)
302304
out[out_selection] = tmp
303305
elif chunk_spec.config.fill_missing_chunks:
@@ -340,7 +342,7 @@ def _merge_chunk_array(
340342
else:
341343
chunk_value = value[out_selection]
342344
# handle missing singleton dimensions
343-
if drop_axes != ():
345+
if drop_axes:
344346
item = tuple(
345347
None # equivalent to np.newaxis
346348
if idx in drop_axes

tests/test_codecs/test_sharding.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -490,7 +490,8 @@ def test_invalid_shard_shape() -> None:
490490
with pytest.raises(
491491
ValueError,
492492
match=re.escape(
493-
"The array's `chunk_shape` (got (16, 16)) needs to be divisible by the shard's inner `chunk_shape` (got (9,))."
493+
"The array's `chunk_shape` (got (16, 16)) needs to be divisible "
494+
"by the shard's inner `chunk_shape` (got (9,))."
494495
),
495496
):
496497
zarr.create_array(
@@ -501,3 +502,56 @@ def test_invalid_shard_shape() -> None:
501502
dtype=np.dtype("uint8"),
502503
fill_value=0,
503504
)
505+
506+
507+
@pytest.mark.parametrize("store", ["local"], indirect=["store"])
508+
def test_sharding_mixed_integer_list_indexing(store: Store) -> None:
509+
"""Regression test for https://github.com/zarr-developers/zarr-python/issues/3691.
510+
511+
Mixed integer/list indexing on sharded arrays should return the same
512+
shape and data as on equivalent chunked arrays.
513+
"""
514+
import numpy as np
515+
516+
data = np.arange(200 * 100 * 10, dtype=np.uint8).reshape(200, 100, 10)
517+
518+
chunked = zarr.create_array(
519+
store,
520+
name="chunked",
521+
shape=(200, 100, 10),
522+
dtype=np.uint8,
523+
chunks=(200, 100, 1),
524+
overwrite=True,
525+
)
526+
chunked[:, :, :] = data
527+
528+
sharded = zarr.create_array(
529+
store,
530+
name="sharded",
531+
shape=(200, 100, 10),
532+
dtype=np.uint8,
533+
chunks=(200, 100, 1),
534+
shards=(200, 100, 10),
535+
overwrite=True,
536+
)
537+
sharded[:, :, :] = data
538+
539+
# Mixed integer + list indexing
540+
c = chunked[0:10, 0, [0, 1]] # type: ignore[index]
541+
s = sharded[0:10, 0, [0, 1]] # type: ignore[index]
542+
assert c.shape == s.shape == (10, 2), ( # type: ignore[union-attr]
543+
f"Expected (10, 2), got chunked={c.shape}, sharded={s.shape}" # type: ignore[union-attr]
544+
)
545+
np.testing.assert_array_equal(c, s)
546+
547+
# Multiple integer axes
548+
c2 = chunked[0, 0, [0, 1, 2]] # type: ignore[index]
549+
s2 = sharded[0, 0, [0, 1, 2]] # type: ignore[index]
550+
assert c2.shape == s2.shape == (3,) # type: ignore[union-attr]
551+
np.testing.assert_array_equal(c2, s2)
552+
553+
# Slice + integer + slice
554+
c3 = chunked[0:5, 1, 0:3]
555+
s3 = sharded[0:5, 1, 0:3]
556+
assert c3.shape == s3.shape == (5, 3) # type: ignore[union-attr]
557+
np.testing.assert_array_equal(c3, s3)

0 commit comments

Comments
 (0)