Skip to content

Commit 51033fb

Browse files
d-v-bclaude
andauthored
refactor(sharding): store chunks_per_shard explicitly in _ShardIndex (zarr-developers#3975)
* refactor(sharding): store chunks_per_shard explicitly in _ShardIndex _ShardIndex previously inferred the chunk grid shape from offsets_and_lengths.shape[:-1]. For 0-D arrays this collapses the array to rank-1, breaking methods that assume rank >= 2 and forcing a numpy compat cast workaround. Store chunks_per_shard as an explicit NamedTuple field instead. This removes the chunks_per_shard property and its cast, and lets several call sites use the field directly instead of reverse-engineering it. Also fix a latent 0-D bug in is_dense, which iterated offsets_and_lengths assuming rank-2. Closes zarr-developers#3974 Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * test: parametrize over dimensionalities * refactor(sharding): store chunks_per_shard explicitly in _ShardIndex _ShardIndex previously inferred the chunk grid shape from offsets_and_lengths.shape[:-1]. For 0-D arrays this collapses the array to rank-1, breaking methods that assume rank >= 2 and forcing a numpy compat cast workaround. Store chunks_per_shard as an explicit NamedTuple field instead. This removes the chunks_per_shard property and its cast, and lets several call sites use the field directly instead of reverse-engineering it. Also fix a latent 0-D bug in is_dense, which iterated offsets_and_lengths assuming rank-2. Tests for get_chunk_slices_vectorized, _ShardReader.__iter__, and is_dense are parametrized over chunk grid ranks (0-D, 1-D, 2-D) so 0-D is exercised as a normal case rather than a special branch. Closes zarr-developers#3974 Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * remove is_dense --------- Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent cdb5846 commit 51033fb

3 files changed

Lines changed: 55 additions & 67 deletions

File tree

changes/3975.misc.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Store `chunks_per_shard` explicitly as a field on `_ShardIndex` instead of inferring it from `offsets_and_lengths.shape[:-1]`. The previous derivation collapsed to rank-1 for 0-D arrays, requiring a numpy-compat cast workaround that is now removed. Also removes the unused `_ShardIndex.is_dense` method, which was ported from an earlier prototype and never had any call sites or tests.

src/zarr/codecs/sharding.py

Lines changed: 24 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from dataclasses import dataclass, replace
55
from enum import Enum
66
from functools import lru_cache
7-
from operator import itemgetter
87
from typing import TYPE_CHECKING, Any, NamedTuple, cast
98

109
import numpy as np
@@ -123,19 +122,15 @@ async def set_if_not_exists(self, default: Buffer) -> None:
123122

124123

125124
class _ShardIndex(NamedTuple):
125+
# the chunk grid shape of a single shard
126+
chunks_per_shard: tuple[int, ...]
126127
# dtype uint64, shape (chunks_per_shard_0, chunks_per_shard_1, ..., 2)
127128
offsets_and_lengths: npt.NDArray[np.uint64]
128129

129-
@property
130-
def chunks_per_shard(self) -> tuple[int, ...]:
131-
result = tuple(self.offsets_and_lengths.shape[0:-1])
132-
# The cast is required until https://github.com/numpy/numpy/pull/27211 is merged
133-
return cast("tuple[int, ...]", result)
134-
135130
def _localize_chunk(self, chunk_coords: tuple[int, ...]) -> tuple[int, ...]:
136131
return tuple(
137132
chunk_i % shard_i
138-
for chunk_i, shard_i in zip(chunk_coords, self.offsets_and_lengths.shape, strict=False)
133+
for chunk_i, shard_i in zip(chunk_coords, self.chunks_per_shard, strict=False)
139134
)
140135

141136
def is_all_empty(self) -> bool:
@@ -171,25 +166,24 @@ def get_chunk_slices_vectorized(
171166
valid : ndarray of shape (n_chunks,)
172167
Boolean mask indicating which chunks are non-empty.
173168
"""
174-
# Handle 0-dimensional arrays (n_dims == 0)
169+
# Handle 0-dimensional arrays (n_dims == 0): the shard holds a single
170+
# chunk, so every coordinate maps to the same flat entry.
175171
if chunk_coords_array.shape[1] == 0:
176-
# offsets_and_lengths has shape (2,) for 0D, reshape to (1, 2)
177-
offsets_and_lengths = self.offsets_and_lengths.reshape(1, 2)
178-
starts = offsets_and_lengths[:, 0]
179-
lengths = offsets_and_lengths[:, 1]
180-
valid = starts != MAX_UINT_64
181-
ends = starts + lengths
182-
return starts, ends, valid
183-
184-
# Localize coordinates via modulo (vectorized)
185-
shard_shape = np.array(self.offsets_and_lengths.shape[:-1], dtype=np.uint64)
186-
localized = chunk_coords_array.astype(np.uint64) % shard_shape
187-
188-
# Build index tuple for advanced indexing
189-
index_tuple = tuple(localized[:, i] for i in range(localized.shape[1]))
190-
191-
# Fetch all offsets and lengths at once
192-
offsets_and_lengths = self.offsets_and_lengths[index_tuple]
172+
offsets_and_lengths = self.offsets_and_lengths.reshape(-1, 2)
173+
offsets_and_lengths = np.broadcast_to(
174+
offsets_and_lengths, (chunk_coords_array.shape[0], 2)
175+
)
176+
else:
177+
# Localize coordinates via modulo (vectorized)
178+
shard_shape = np.array(self.chunks_per_shard, dtype=np.uint64)
179+
localized = chunk_coords_array.astype(np.uint64) % shard_shape
180+
181+
# Build index tuple for advanced indexing
182+
index_tuple = tuple(localized[:, i] for i in range(localized.shape[1]))
183+
184+
# Fetch all offsets and lengths at once
185+
offsets_and_lengths = self.offsets_and_lengths[index_tuple]
186+
193187
starts = offsets_and_lengths[:, 0]
194188
lengths = offsets_and_lengths[:, 1]
195189

@@ -211,32 +205,11 @@ def set_chunk_slice(self, chunk_coords: tuple[int, ...], chunk_slice: slice | No
211205
chunk_slice.stop - chunk_slice.start,
212206
)
213207

214-
def is_dense(self, chunk_byte_length: int) -> bool:
215-
sorted_offsets_and_lengths = sorted(
216-
[
217-
(offset, length)
218-
for offset, length in self.offsets_and_lengths
219-
if offset != MAX_UINT_64
220-
],
221-
key=itemgetter(0),
222-
)
223-
224-
# Are all non-empty offsets unique?
225-
if len(
226-
{offset for offset, _ in sorted_offsets_and_lengths if offset != MAX_UINT_64}
227-
) != len(sorted_offsets_and_lengths):
228-
return False
229-
230-
return all(
231-
offset % chunk_byte_length == 0 and length == chunk_byte_length
232-
for offset, length in sorted_offsets_and_lengths
233-
)
234-
235208
@classmethod
236209
def create_empty(cls, chunks_per_shard: tuple[int, ...]) -> _ShardIndex:
237210
offsets_and_lengths = np.zeros(chunks_per_shard + (2,), dtype="<u8", order="C")
238211
offsets_and_lengths.fill(MAX_UINT_64)
239-
return cls(offsets_and_lengths)
212+
return cls(chunks_per_shard, offsets_and_lengths)
240213

241214

242215
class _ShardReader(ShardMapping):
@@ -280,7 +253,7 @@ def __len__(self) -> int:
280253
return int(self.index.offsets_and_lengths.size / 2)
281254

282255
def __iter__(self) -> Iterator[tuple[int, ...]]:
283-
return c_order_iter(self.index.offsets_and_lengths.shape[:-1])
256+
return c_order_iter(self.index.chunks_per_shard)
284257

285258
def to_dict_vectorized(
286259
self,
@@ -298,8 +271,7 @@ def to_dict_vectorized(
298271
dict mapping chunk coordinate tuples to Buffer or None
299272
"""
300273
starts, ends, valid = self.index.get_chunk_slices_vectorized(chunk_coords_array)
301-
chunks_per_shard = tuple(self.index.offsets_and_lengths.shape[:-1])
302-
chunk_coords_keys = _morton_order_keys(chunks_per_shard)
274+
chunk_coords_keys = _morton_order_keys(self.index.chunks_per_shard)
303275

304276
result: dict[tuple[int, ...], Buffer | None] = {}
305277
for i, coords in enumerate(chunk_coords_keys):
@@ -712,7 +684,7 @@ async def _decode_shard_index(
712684
)
713685
# This cannot be None because we have the bytes already
714686
index_array = cast(NDBuffer, index_array)
715-
return _ShardIndex(index_array.as_numpy_array())
687+
return _ShardIndex(chunks_per_shard, index_array.as_numpy_array())
716688

717689
async def _encode_shard_index(self, index: _ShardIndex) -> Buffer:
718690
index_bytes = next(

tests/test_codecs/test_sharding.py

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
)
1919
from zarr.codecs.sharding import MAX_UINT_64, _ShardIndex
2020
from zarr.core.buffer import NDArrayLike, default_buffer_prototype
21+
from zarr.core.indexing import c_order_iter
2122
from zarr.storage import StorePath, ZipStore
2223

2324
from ..conftest import ArrayRequest
@@ -567,18 +568,32 @@ def test_sharding_zero_dimensional() -> None:
567568
assert arr[()] == pytest.approx(43.0)
568569

569570

570-
def test_shard_index_get_chunk_slices_vectorized_zero_dimensional() -> None:
571-
"""Directly cover the 0-D path in _ShardIndex.get_chunk_slices_vectorized."""
572-
# For a 0D array offsets_and_lengths has shape (2,) — reshape to (1, 2) inside.
573-
index = _ShardIndex(np.array([10, 4], dtype=np.uint64))
574-
chunk_coords = np.empty((1, 0), dtype=np.uint64)
575-
starts, ends, valid = index.get_chunk_slices_vectorized(chunk_coords)
576-
np.testing.assert_array_equal(starts, np.array([10], dtype=np.uint64))
577-
np.testing.assert_array_equal(ends, np.array([14], dtype=np.uint64))
578-
np.testing.assert_array_equal(valid, np.array([True]))
579-
580-
# Empty/unwritten chunk case
581-
index_empty = _ShardIndex(np.array([MAX_UINT_64, MAX_UINT_64], dtype=np.uint64))
582-
starts_e, _ends_e, valid_e = index_empty.get_chunk_slices_vectorized(chunk_coords)
583-
np.testing.assert_array_equal(starts_e, np.array([MAX_UINT_64], dtype=np.uint64))
584-
np.testing.assert_array_equal(valid_e, np.array([False]))
571+
def test_shard_index_stores_chunks_per_shard_explicitly() -> None:
572+
"""_ShardIndex stores the chunk grid shape as an explicit field."""
573+
index = _ShardIndex.create_empty((2, 3))
574+
assert index.chunks_per_shard == (2, 3)
575+
576+
# 0-D: chunks_per_shard is the empty tuple, distinct from the array's rank
577+
index_0d = _ShardIndex.create_empty(())
578+
assert index_0d.chunks_per_shard == ()
579+
580+
581+
@pytest.mark.parametrize("chunks_per_shard", [(), (3,), (2, 3)])
582+
def test_shard_index_get_chunk_slices_vectorized(chunks_per_shard: tuple[int, ...]) -> None:
583+
"""get_chunk_slices_vectorized works uniformly across chunk grid ranks, including 0-D."""
584+
index = _ShardIndex.create_empty(chunks_per_shard)
585+
# Write the first chunk; leave the rest (if any) empty.
586+
all_coords = list(c_order_iter(chunks_per_shard))
587+
index.set_chunk_slice(all_coords[0], slice(10, 14))
588+
589+
coords_array = np.array(all_coords, dtype=np.uint64).reshape(
590+
len(all_coords), len(chunks_per_shard)
591+
)
592+
starts, ends, valid = index.get_chunk_slices_vectorized(coords_array)
593+
594+
expected_valid = np.zeros(len(all_coords), dtype=bool)
595+
expected_valid[0] = True
596+
np.testing.assert_array_equal(valid, expected_valid)
597+
assert starts[0] == 10
598+
assert ends[0] == 14
599+
np.testing.assert_array_equal(starts[~expected_valid], MAX_UINT_64)

0 commit comments

Comments
 (0)