Skip to content

Commit 93dbf78

Browse files
abishop1990, Cipher, and d-v-b
authored
fix: apply drop_axes squeeze in partial decode path for sharding (#3691) (#3763)
* fix: apply drop_axes squeeze in partial decode path for sharding

  When reading sharded arrays with mixed integer/list indexing (e.g. arr[0:10, 0, [0, 1]]), the outer OrthogonalIndexer produces chunk selections that have been ix_()-transformed for orthogonal advanced indexing. Integer indices become single-element ranges (size-1 dims) via ix_() to enable NumPy orthogonal indexing.

  In CodecPipeline.read_batch(), the non-partial path correctly applies drop_axes.squeeze() to remove those size-1 integer dimensions before writing to the output buffer. However, the partial decode path (used by ShardingCodec) was missing this squeeze step.

  Fixes #3691

  Also: Fix line length violation in test error message to comply with 100 character linting limit.

* fix(mypy): add type ignore comments for dynamic array indexing in sharding test

  The test uses complex indexing patterns (mixed integer/list indices) that mypy's zarr.Array stubs don't recognize as valid. Add specific type ignore comments for [index] and [union-attr] errors to suppress false positives.

* fix(mypy): correct type-ignore codes for union attribute access in sharding test

  - Line 542: Fix assert accessing .shape by changing from [index] to [union-attr]
  - Line 544: Add missing type-ignore[union-attr] for f-string .shape access
  - Lines 554-555: Remove unused type-ignore[index] comments on assignments

  The mypy errors were caused by indexing operations returning union types that include scalar types (int, float, etc.), which don't have a .shape attribute. The proper fix uses type-ignore[union-attr] for attribute access, not [index].

---------

Co-authored-by: Cipher <cipher@openclaw.ai>
Co-authored-by: Davis Bennett <davis.v.bennett@gmail.com>
1 parent a02d996 commit 93dbf78

File tree

2 files changed

+59
-3
lines changed

2 files changed

+59
-3
lines changed

src/zarr/core/codec_pipeline.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,8 @@ async def read_batch(
263263
chunk_array_batch, batch_info, strict=False
264264
):
265265
if chunk_array is not None:
266+
if drop_axes:
267+
chunk_array = chunk_array.squeeze(axis=drop_axes)
266268
out[out_selection] = chunk_array
267269
else:
268270
out[out_selection] = fill_value_or_default(chunk_spec)
@@ -285,7 +287,7 @@ async def read_batch(
285287
):
286288
if chunk_array is not None:
287289
tmp = chunk_array[chunk_selection]
288-
if drop_axes != ():
290+
if drop_axes:
289291
tmp = tmp.squeeze(axis=drop_axes)
290292
out[out_selection] = tmp
291293
else:
@@ -324,7 +326,7 @@ def _merge_chunk_array(
324326
else:
325327
chunk_value = value[out_selection]
326328
# handle missing singleton dimensions
327-
if drop_axes != ():
329+
if drop_axes:
328330
item = tuple(
329331
None # equivalent to np.newaxis
330332
if idx in drop_axes

tests/test_codecs/test_sharding.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -490,7 +490,8 @@ def test_invalid_shard_shape() -> None:
490490
with pytest.raises(
491491
ValueError,
492492
match=re.escape(
493-
"The array's `chunk_shape` (got (16, 16)) needs to be divisible by the shard's inner `chunk_shape` (got (9,))."
493+
"The array's `chunk_shape` (got (16, 16)) needs to be divisible "
494+
"by the shard's inner `chunk_shape` (got (9,))."
494495
),
495496
):
496497
zarr.create_array(
@@ -501,3 +502,56 @@ def test_invalid_shard_shape() -> None:
501502
dtype=np.dtype("uint8"),
502503
fill_value=0,
503504
)
505+
506+
507+
@pytest.mark.parametrize("store", ["local"], indirect=["store"])
def test_sharding_mixed_integer_list_indexing(store: Store) -> None:
    """Regression test for https://github.com/zarr-developers/zarr-python/issues/3691.

    Mixed integer/list indexing on sharded arrays should return the same
    shape and data as on equivalent chunked arrays, and both should match
    the data that was originally written.
    """
    # uint8 wraps around; the repeating 0..255 pattern is still distinctive
    # enough to detect a wrong axis being dropped or kept.
    data = np.arange(200 * 100 * 10, dtype=np.uint8).reshape(200, 100, 10)

    chunked = zarr.create_array(
        store,
        name="chunked",
        shape=(200, 100, 10),
        dtype=np.uint8,
        chunks=(200, 100, 1),
        overwrite=True,
    )
    chunked[:, :, :] = data

    # Same inner chunking, but all chunks packed into a single shard so that
    # reads are served by the sharding codec.
    sharded = zarr.create_array(
        store,
        name="sharded",
        shape=(200, 100, 10),
        dtype=np.uint8,
        chunks=(200, 100, 1),
        shards=(200, 100, 10),
        overwrite=True,
    )
    sharded[:, :, :] = data

    # Mixed integer + list indexing
    c = chunked[0:10, 0, [0, 1]]  # type: ignore[index]
    s = sharded[0:10, 0, [0, 1]]  # type: ignore[index]
    assert c.shape == s.shape == (10, 2), (  # type: ignore[union-attr]
        f"Expected (10, 2), got chunked={c.shape}, sharded={s.shape}"  # type: ignore[union-attr]
    )
    np.testing.assert_array_equal(c, s)
    # Compare against the source data too, so a bug shared by both read
    # paths cannot slip through. Index in two steps to keep the NumPy
    # expectation unambiguous (integer first, then the list on the last axis).
    np.testing.assert_array_equal(s, data[0:10, 0][:, [0, 1]])

    # Multiple integer axes
    c2 = chunked[0, 0, [0, 1, 2]]  # type: ignore[index]
    s2 = sharded[0, 0, [0, 1, 2]]  # type: ignore[index]
    assert c2.shape == s2.shape == (3,)  # type: ignore[union-attr]
    np.testing.assert_array_equal(c2, s2)
    np.testing.assert_array_equal(s2, data[0, 0][[0, 1, 2]])

    # Slice + integer + slice
    c3 = chunked[0:5, 1, 0:3]
    s3 = sharded[0:5, 1, 0:3]
    assert c3.shape == s3.shape == (5, 3)  # type: ignore[union-attr]
    np.testing.assert_array_equal(c3, s3)
    np.testing.assert_array_equal(s3, data[0:5, 1, 0:3])

0 commit comments

Comments
 (0)