Skip to content

Commit 63d5e87

Browse files
authored
Merge branch 'main' into fix/cache-store-byte-range
2 parents b097214 + 879e1ce commit 63d5e87

13 files changed

Lines changed: 508 additions & 50 deletions

File tree

changes/3655.bugfix.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fixed a bug in the sharding codec that prevented nested shard reads in certain cases.

changes/3702.bugfix.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Skip chunk coordinate enumeration in resize when the array is only growing, avoiding unbounded memory usage for large arrays.

changes/3708.misc.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Optimize Morton order computation with hypercube optimization, vectorized decoding, and singleton dimension removal, providing 10-45x speedup for typical chunk shapes.

changes/3712.misc.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Added benchmarks for Morton order computation in sharded arrays.

changes/3713.misc.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Vectorize get_chunk_slice for faster sharded array writes.

changes/3717.misc.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add benchmarks for Morton order computation with non-power-of-2 and near-miss shard shapes, covering both pure computation and end-to-end read/write performance.

src/zarr/codecs/sharding.py

Lines changed: 81 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,15 @@
4646
from zarr.core.indexing import (
4747
BasicIndexer,
4848
SelectorTuple,
49+
_morton_order,
50+
_morton_order_keys,
4951
c_order_iter,
5052
get_indexer,
5153
morton_order_iter,
5254
)
5355
from zarr.core.metadata.v3 import parse_codecs
5456
from zarr.registry import get_ndbuffer_class, get_pipeline_class
57+
from zarr.storage._utils import _normalize_byte_range_index
5558

5659
if TYPE_CHECKING:
5760
from collections.abc import Iterator
@@ -86,11 +89,16 @@ class _ShardingByteGetter(ByteGetter):
8689
async def get(
    self, prototype: BufferPrototype, byte_range: ByteRequest | None = None
) -> Buffer | None:
    """Look up this getter's chunk in the in-memory shard dict.

    Returns ``None`` when the chunk is absent. Otherwise returns the whole
    chunk buffer, or just the requested slice when ``byte_range`` is given.
    """
    # Shards only support the default buffer prototype for now.
    assert prototype == default_buffer_prototype(), (
        f"prototype is not supported within shards currently. diff: {prototype} != {default_buffer_prototype()}"
    )
    value = self.shard_dict.get(self.chunk_coords)
    if value is None:
        return None
    if byte_range is None:
        return value
    # Resolve the ByteRequest into absolute (start, stop) positions and slice.
    start, stop = _normalize_byte_range_index(value, byte_range)
    return value[start:stop]
94102

95103

96104
@dataclass(frozen=True)
@@ -138,6 +146,45 @@ def get_chunk_slice(self, chunk_coords: tuple[int, ...]) -> tuple[int, int] | No
138146
else:
139147
return (int(chunk_start), int(chunk_start + chunk_len))
140148

149+
def get_chunk_slices_vectorized(
    self, chunk_coords_array: npt.NDArray[np.integer[Any]]
) -> tuple[npt.NDArray[np.uint64], npt.NDArray[np.uint64], npt.NDArray[np.bool_]]:
    """Get chunk slices for multiple coordinates at once.

    Parameters
    ----------
    chunk_coords_array : ndarray of shape (n_chunks, n_dims)
        Array of chunk coordinates to look up.

    Returns
    -------
    starts : ndarray of shape (n_chunks,)
        Start byte positions for each chunk.
    ends : ndarray of shape (n_chunks,)
        End byte positions for each chunk.
    valid : ndarray of shape (n_chunks,)
        Boolean mask indicating which chunks are non-empty.
    """
    # Wrap coordinates into the shard grid with a vectorized modulo.
    grid_shape = np.array(self.offsets_and_lengths.shape[:-1], dtype=np.uint64)
    local = chunk_coords_array.astype(np.uint64) % grid_shape

    # A single advanced-indexing lookup fetches every (offset, length) pair.
    per_dim_index = tuple(local[:, d] for d in range(local.shape[1]))
    entries = self.offsets_and_lengths[per_dim_index]
    starts, lengths = entries[:, 0], entries[:, 1]

    # An offset of MAX_UINT_64 marks an empty (unwritten) chunk.
    valid = starts != MAX_UINT_64
    return starts, starts + lengths, valid
187+
141188
def set_chunk_slice(self, chunk_coords: tuple[int, ...], chunk_slice: slice | None) -> None:
142189
localized_chunk = self._localize_chunk(chunk_coords)
143190
if chunk_slice is None:
@@ -219,6 +266,34 @@ def __len__(self) -> int:
219266
def __iter__(self) -> Iterator[tuple[int, ...]]:
220267
return c_order_iter(self.index.offsets_and_lengths.shape[:-1])
221268

269+
def to_dict_vectorized(
    self,
    chunk_coords_array: npt.NDArray[np.integer[Any]],
) -> dict[tuple[int, ...], Buffer | None]:
    """Build a dict of chunk coordinates to buffers using vectorized lookup.

    Parameters
    ----------
    chunk_coords_array : ndarray of shape (n_chunks, n_dims)
        Array of chunk coordinates for vectorized index lookup.

    Returns
    -------
    dict mapping chunk coordinate tuples to Buffer or None
    """
    starts, ends, valid = self.index.get_chunk_slices_vectorized(chunk_coords_array)
    grid = tuple(self.index.offsets_and_lengths.shape[:-1])
    # Keys come back in Morton order so the dict iterates in that order.
    keys = _morton_order_keys(grid)

    # Slice the shard buffer for present chunks; absent chunks map to None.
    mapping: dict[tuple[int, ...], Buffer | None] = {}
    for pos, key in enumerate(keys):
        if valid[pos]:
            mapping[key] = self.buf[int(starts[pos]) : int(ends[pos])]
        else:
            mapping[key] = None
    return mapping
296+
222297

223298
@dataclass(frozen=True)
224299
class ShardingCodec(
@@ -505,7 +580,8 @@ async def _encode_partial_single(
505580
chunks_per_shard=chunks_per_shard,
506581
)
507582
shard_reader = shard_reader or _ShardReader.create_empty(chunks_per_shard)
508-
shard_dict = {k: shard_reader.get(k) for k in morton_order_iter(chunks_per_shard)}
583+
# Use vectorized lookup for better performance
584+
shard_dict = shard_reader.to_dict_vectorized(np.asarray(_morton_order(chunks_per_shard)))
509585

510586
indexer = list(
511587
get_indexer(
@@ -597,7 +673,8 @@ async def _decode_shard_index(
597673
)
598674
)
599675
)
600-
assert index_array is not None
676+
# This cannot be None because we have the bytes already
677+
index_array = cast(NDBuffer, index_array)
601678
return _ShardIndex(index_array.as_numpy_array())
602679

603680
async def _encode_shard_index(self, index: _ShardIndex) -> Buffer:

src/zarr/core/array.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5990,7 +5990,10 @@ async def _resize(
59905990
assert len(new_shape) == len(array.metadata.shape)
59915991
new_metadata = array.metadata.update_shape(new_shape)
59925992

5993-
if delete_outside_chunks:
5993+
# ensure deletion is only run if array is shrinking as the delete_outside_chunks path is unbounded in memory
5994+
only_growing = all(new >= old for new, old in zip(new_shape, array.metadata.shape, strict=True))
5995+
5996+
if delete_outside_chunks and not only_growing:
59945997
# Remove all chunks outside of the new shape
59955998
old_chunk_coords = set(array.metadata.chunk_grid.all_chunk_coords(array.metadata.shape))
59965999
new_chunk_coords = set(array.metadata.chunk_grid.all_chunk_coords(new_shape))

src/zarr/core/indexing.py

Lines changed: 95 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1452,7 +1452,7 @@ def make_slice_selection(selection: Any) -> list[slice]:
14521452
def decode_morton(z: int, chunk_shape: tuple[int, ...]) -> tuple[int, ...]:
14531453
# Inspired by compressed morton code as implemented in Neuroglancer
14541454
# https://github.com/google/neuroglancer/blob/master/src/neuroglancer/datasource/precomputed/volume.md#compressed-morton-code
1455-
bits = tuple(math.ceil(math.log2(c)) for c in chunk_shape)
1455+
bits = tuple((c - 1).bit_length() for c in chunk_shape)
14561456
max_coords_bits = max(bits)
14571457
input_bit = 0
14581458
input_value = z
@@ -1467,21 +1467,104 @@ def decode_morton(z: int, chunk_shape: tuple[int, ...]) -> tuple[int, ...]:
14671467
return tuple(out)
14681468

14691469

1470-
@lru_cache
1471-
def _morton_order(chunk_shape: tuple[int, ...]) -> tuple[tuple[int, ...], ...]:
1470+
def decode_morton_vectorized(
    z: npt.NDArray[np.intp], chunk_shape: tuple[int, ...]
) -> npt.NDArray[np.intp]:
    """Vectorized Morton code decoding for multiple z values.

    Parameters
    ----------
    z : ndarray
        1D array of Morton codes to decode.
    chunk_shape : tuple of int
        Shape defining the coordinate space.

    Returns
    -------
    ndarray
        2D array of shape (len(z), len(chunk_shape)) containing decoded coordinates.
    """
    ndim = len(chunk_shape)
    # Bits needed to index each dimension (0 for singleton dimensions).
    bits_per_dim = tuple((c - 1).bit_length() for c in chunk_shape)
    widest = max(bits_per_dim) if bits_per_dim else 0

    decoded = np.zeros((len(z), ndim), dtype=np.intp)
    src_bit = 0
    for coord_bit in range(widest):
        for axis in range(ndim):
            if coord_bit >= bits_per_dim[axis]:
                continue
            # Pull bit `src_bit` out of every code at once and deposit it
            # at position `coord_bit` of this axis's coordinate.
            decoded[:, axis] |= ((z >> src_bit) & 1) << coord_bit
            src_bit += 1
    return decoded
1504+
1505+
1506+
@lru_cache(maxsize=16)
def _morton_order(chunk_shape: tuple[int, ...]) -> npt.NDArray[np.intp]:
    """Return every chunk coordinate of ``chunk_shape`` in Morton (Z-curve) order.

    The result is a read-only ``(n_total, n_dims)`` intp array, cached per shape.
    """
    n_total = product(chunk_shape)
    n_dims = len(chunk_shape)
    if n_total == 0:
        empty = np.empty((0, n_dims), dtype=np.intp)
        empty.flags.writeable = False
        return empty

    # Size of the smallest power-of-2 hypercube whose Morton codes cover all
    # valid coordinates: (c - 1).bit_length() bits are needed to index each
    # dimension (0 for singleton dims), so the hypercube holds 2**total_bits codes.
    total_bits = sum((c - 1).bit_length() for c in chunk_shape)
    n_z = 1 << total_bits if total_bits > 0 else 1

    ordered: npt.NDArray[np.intp]
    if n_z <= 4 * n_total:
        # Ceiling strategy: decode every code in the hypercube (fully
        # vectorized), then keep only in-bounds coordinates. Cheap while the
        # overgeneration ratio n_z / n_total stays small (<= 4).
        codes = np.arange(n_z, dtype=np.intp)
        coords = decode_morton_vectorized(codes, chunk_shape)
        in_bounds = np.all(coords < np.array(chunk_shape, dtype=np.intp), axis=1)
        ordered = coords[in_bounds]
    else:
        # Argsort strategy: enumerate just the n_total valid coordinates,
        # encode each to its Morton code, and sort by code. Avoids the 8x or
        # larger overgeneration penalty of near-miss shapes like (33, 33, 33):
        # O(n_total * bits) encode + O(n_total log n_total) sort, versus
        # O(n_z * bits) for the ceiling strategy.
        axes = [np.arange(c, dtype=np.intp) for c in chunk_shape]
        mesh = np.meshgrid(*axes, indexing="ij")
        coords = np.stack([g.ravel() for g in mesh], axis=1)

        # Vectorized Morton encode of every coordinate row.
        bits_per_dim = tuple((c - 1).bit_length() for c in chunk_shape)
        codes = np.zeros(n_total, dtype=np.intp)
        out_bit = 0
        for coord_bit in range(max(bits_per_dim)):
            for dim in range(n_dims):
                if coord_bit < bits_per_dim[dim]:
                    codes |= ((coords[:, dim] >> coord_bit) & 1) << out_bit
                    out_bit += 1

        order_idx: npt.NDArray[np.intp] = np.argsort(codes, kind="stable")
        ordered = np.asarray(coords[order_idx], dtype=np.intp)

    # Freeze the cached array so callers cannot corrupt the shared cache entry.
    ordered.flags.writeable = False
    return ordered
1559+
1560+
1561+
@lru_cache(maxsize=16)
def _morton_order_keys(chunk_shape: tuple[int, ...]) -> tuple[tuple[int, ...], ...]:
    """Morton-ordered chunk coordinates as a cached tuple of plain int tuples."""
    return tuple(tuple(map(int, row)) for row in _morton_order(chunk_shape))


def morton_order_iter(chunk_shape: tuple[int, ...]) -> Iterator[tuple[int, ...]]:
    """Iterate the chunk coordinates of ``chunk_shape`` in Morton order."""
    return iter(_morton_order_keys(tuple(chunk_shape)))
14851568

14861569

14871570
def c_order_iter(chunks_per_shard: tuple[int, ...]) -> Iterator[tuple[int, ...]]:

src/zarr/testing/stateful.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -340,13 +340,13 @@ def delete_array_using_del(self, data: DataObject) -> None:
340340
self.all_arrays.remove(array_path)
341341

342342
@precondition(lambda self: self.store.supports_deletes)
343-
@precondition(lambda self: len(self.all_groups) >= 2) # fixme don't delete root
343+
@precondition(lambda self: bool(self.all_groups))
344344
@rule(data=st.data())
345345
def delete_group_using_del(self, data: DataObject) -> None:
346-
# ensure that we don't include the root group in the list of member names that we try
347-
# to delete
348-
member_names = tuple(filter(lambda v: "/" in v, sorted(self.all_groups)))
349-
group_path = data.draw(st.sampled_from(member_names), label="Group deletion target")
346+
group_path = data.draw(
347+
st.sampled_from(sorted(self.all_groups)),
348+
label="Group deletion target",
349+
)
350350
prefix, group_name = split_prefix_name(group_path)
351351
note(f"Deleting group '{group_path=!r}', {prefix=!r}, {group_name=!r} using delete")
352352
members = zarr.open_group(store=self.model, path=group_path).members(max_depth=None)
@@ -359,9 +359,7 @@ def delete_group_using_del(self, data: DataObject) -> None:
359359
group = zarr.open_group(store=store, path=prefix)
360360
group[group_name] # check that it exists
361361
del group[group_name]
362-
if group_path != "/":
363-
# The root group is always present
364-
self.all_groups.remove(group_path)
362+
self.all_groups.remove(group_path)
365363

366364
# # --------------- assertions -----------------
367365
# def check_group_arrays(self, group):

0 commit comments

Comments (0)