Skip to content

Commit 7039de9

Browse files
authored
Merge branch 'main' into coalesce-shard-reads
2 parents c9875d5 + 32c7ab9 commit 7039de9

13 files changed

Lines changed: 511 additions & 46 deletions

File tree

changes/3655.bugfix.md

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1 @@
1+
Fixed a bug in the sharding codec that prevented nested shard reads in certain cases.

changes/3702.bugfix.md

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1 @@
1+
Skip chunk coordinate enumeration in resize when the array is only growing, avoiding unbounded memory usage for large arrays.

changes/3708.misc.md

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1 @@
1+
Optimize Morton order computation with hypercube optimization, vectorized decoding, and singleton dimension removal, providing 10-45x speedup for typical chunk shapes.

changes/3712.misc.md

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1 @@
1+
Added benchmarks for Morton order computation in sharded arrays.

changes/3713.misc.md

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1 @@
1+
Vectorize get_chunk_slice for faster sharded array writes.

changes/3717.misc.md

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1 @@
1+
Add benchmarks for Morton order computation with non-power-of-2 and near-miss shard shapes, covering both pure computation and end-to-end read/write performance.

src/zarr/codecs/sharding.py

Lines changed: 81 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -48,12 +48,15 @@
4848
from zarr.core.indexing import (
4949
BasicIndexer,
5050
SelectorTuple,
51+
_morton_order,
52+
_morton_order_keys,
5153
c_order_iter,
5254
get_indexer,
5355
morton_order_iter,
5456
)
5557
from zarr.core.metadata.v3 import parse_codecs
5658
from zarr.registry import get_ndbuffer_class, get_pipeline_class
59+
from zarr.storage._utils import _normalize_byte_range_index
5760

5861
if TYPE_CHECKING:
5962
from collections.abc import Iterator
@@ -88,11 +91,16 @@ class _ShardingByteGetter(ByteGetter):
8891
async def get(
8992
self, prototype: BufferPrototype, byte_range: ByteRequest | None = None
9093
) -> Buffer | None:
91-
assert byte_range is None, "byte_range is not supported within shards"
9294
assert prototype == default_buffer_prototype(), (
9395
f"prototype is not supported within shards currently. diff: {prototype} != {default_buffer_prototype()}"
9496
)
95-
return self.shard_dict.get(self.chunk_coords)
97+
value = self.shard_dict.get(self.chunk_coords)
98+
if value is None:
99+
return None
100+
if byte_range is None:
101+
return value
102+
start, stop = _normalize_byte_range_index(value, byte_range)
103+
return value[start:stop]
96104

97105

98106
@dataclass(frozen=True)
@@ -138,6 +146,45 @@ def get_chunk_slice(self, chunk_coords: tuple[int, ...]) -> tuple[int, int] | No
138146
else:
139147
return (int(chunk_start), int(chunk_start + chunk_len))
140148

149+
def get_chunk_slices_vectorized(
150+
self, chunk_coords_array: npt.NDArray[np.integer[Any]]
151+
) -> tuple[npt.NDArray[np.uint64], npt.NDArray[np.uint64], npt.NDArray[np.bool_]]:
152+
"""Get chunk slices for multiple coordinates at once.
153+
154+
Parameters
155+
----------
156+
chunk_coords_array : ndarray of shape (n_chunks, n_dims)
157+
Array of chunk coordinates to look up.
158+
159+
Returns
160+
-------
161+
starts : ndarray of shape (n_chunks,)
162+
Start byte positions for each chunk.
163+
ends : ndarray of shape (n_chunks,)
164+
End byte positions for each chunk.
165+
valid : ndarray of shape (n_chunks,)
166+
Boolean mask indicating which chunks are non-empty.
167+
"""
168+
# Localize coordinates via modulo (vectorized)
169+
shard_shape = np.array(self.offsets_and_lengths.shape[:-1], dtype=np.uint64)
170+
localized = chunk_coords_array.astype(np.uint64) % shard_shape
171+
172+
# Build index tuple for advanced indexing
173+
index_tuple = tuple(localized[:, i] for i in range(localized.shape[1]))
174+
175+
# Fetch all offsets and lengths at once
176+
offsets_and_lengths = self.offsets_and_lengths[index_tuple]
177+
starts = offsets_and_lengths[:, 0]
178+
lengths = offsets_and_lengths[:, 1]
179+
180+
# Check for valid (non-empty) chunks
181+
valid = starts != MAX_UINT_64
182+
183+
# Compute end positions
184+
ends = starts + lengths
185+
186+
return starts, ends, valid
187+
141188
def set_chunk_slice(self, chunk_coords: tuple[int, ...], chunk_slice: slice | None) -> None:
142189
localized_chunk = self._localize_chunk(chunk_coords)
143190
if chunk_slice is None:
@@ -219,6 +266,34 @@ def __len__(self) -> int:
219266
def __iter__(self) -> Iterator[tuple[int, ...]]:
220267
return c_order_iter(self.index.offsets_and_lengths.shape[:-1])
221268

269+
def to_dict_vectorized(
270+
self,
271+
chunk_coords_array: npt.NDArray[np.integer[Any]],
272+
) -> dict[tuple[int, ...], Buffer | None]:
273+
"""Build a dict of chunk coordinates to buffers using vectorized lookup.
274+
275+
Parameters
276+
----------
277+
chunk_coords_array : ndarray of shape (n_chunks, n_dims)
278+
Array of chunk coordinates for vectorized index lookup.
279+
280+
Returns
281+
-------
282+
dict mapping chunk coordinate tuples to Buffer or None
283+
"""
284+
starts, ends, valid = self.index.get_chunk_slices_vectorized(chunk_coords_array)
285+
chunks_per_shard = tuple(self.index.offsets_and_lengths.shape[:-1])
286+
chunk_coords_keys = _morton_order_keys(chunks_per_shard)
287+
288+
result: dict[tuple[int, ...], Buffer | None] = {}
289+
for i, coords in enumerate(chunk_coords_keys):
290+
if valid[i]:
291+
result[coords] = self.buf[int(starts[i]) : int(ends[i])]
292+
else:
293+
result[coords] = None
294+
295+
return result
296+
222297

223298
@dataclass(frozen=True)
224299
class _ChunkCoordsByteSlice:
@@ -514,7 +589,8 @@ async def _encode_partial_single(
514589
chunks_per_shard=chunks_per_shard,
515590
)
516591
shard_reader = shard_reader or _ShardReader.create_empty(chunks_per_shard)
517-
shard_dict = {k: shard_reader.get(k) for k in morton_order_iter(chunks_per_shard)}
592+
# Use vectorized lookup for better performance
593+
shard_dict = shard_reader.to_dict_vectorized(np.asarray(_morton_order(chunks_per_shard)))
518594

519595
indexer = list(
520596
get_indexer(
@@ -608,7 +684,8 @@ async def _decode_shard_index(
608684
)
609685
)
610686
)
611-
assert index_array is not None
687+
# This cannot be None because we have the bytes already
688+
index_array = cast(NDBuffer, index_array)
612689
return _ShardIndex(index_array.as_numpy_array())
613690

614691
async def _encode_shard_index(self, index: _ShardIndex) -> Buffer:

src/zarr/core/array.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -5990,7 +5990,10 @@ async def _resize(
59905990
assert len(new_shape) == len(array.metadata.shape)
59915991
new_metadata = array.metadata.update_shape(new_shape)
59925992

5993-
if delete_outside_chunks:
5993+
# ensure deletion is only run if array is shrinking as the delete_outside_chunks path is unbounded in memory
5994+
only_growing = all(new >= old for new, old in zip(new_shape, array.metadata.shape, strict=True))
5995+
5996+
if delete_outside_chunks and not only_growing:
59945997
# Remove all chunks outside of the new shape
59955998
old_chunk_coords = set(array.metadata.chunk_grid.all_chunk_coords(array.metadata.shape))
59965999
new_chunk_coords = set(array.metadata.chunk_grid.all_chunk_coords(new_shape))

src/zarr/core/indexing.py

Lines changed: 98 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -1452,7 +1452,7 @@ def make_slice_selection(selection: Any) -> list[slice]:
14521452
def decode_morton(z: int, chunk_shape: tuple[int, ...]) -> tuple[int, ...]:
14531453
# Inspired by compressed morton code as implemented in Neuroglancer
14541454
# https://github.com/google/neuroglancer/blob/master/src/neuroglancer/datasource/precomputed/volume.md#compressed-morton-code
1455-
bits = tuple(math.ceil(math.log2(c)) for c in chunk_shape)
1455+
bits = tuple((c - 1).bit_length() for c in chunk_shape)
14561456
max_coords_bits = max(bits)
14571457
input_bit = 0
14581458
input_value = z
@@ -1467,21 +1467,110 @@ def decode_morton(z: int, chunk_shape: tuple[int, ...]) -> tuple[int, ...]:
14671467
return tuple(out)
14681468

14691469

1470-
@lru_cache
1471-
def _morton_order(chunk_shape: tuple[int, ...]) -> tuple[tuple[int, ...], ...]:
1470+
def decode_morton_vectorized(
1471+
z: npt.NDArray[np.intp], chunk_shape: tuple[int, ...]
1472+
) -> npt.NDArray[np.intp]:
1473+
"""Vectorized Morton code decoding for multiple z values.
1474+
1475+
Parameters
1476+
----------
1477+
z : ndarray
1478+
1D array of Morton codes to decode.
1479+
chunk_shape : tuple of int
1480+
Shape defining the coordinate space.
1481+
1482+
Returns
1483+
-------
1484+
ndarray
1485+
2D array of shape (len(z), len(chunk_shape)) containing decoded coordinates.
1486+
"""
1487+
n_dims = len(chunk_shape)
1488+
bits = tuple((c - 1).bit_length() for c in chunk_shape)
1489+
1490+
max_coords_bits = max(bits) if bits else 0
1491+
out = np.zeros((len(z), n_dims), dtype=np.intp)
1492+
1493+
input_bit = 0
1494+
for coord_bit in range(max_coords_bits):
1495+
for dim in range(n_dims):
1496+
if coord_bit < bits[dim]:
1497+
# Extract bit at position input_bit from all z values
1498+
bit_values = (z >> input_bit) & 1
1499+
# Place bit at coord_bit position in dimension dim
1500+
out[:, dim] |= bit_values << coord_bit
1501+
input_bit += 1
1502+
1503+
return out
1504+
1505+
1506+
@lru_cache(maxsize=16)
1507+
def _morton_order(chunk_shape: tuple[int, ...]) -> npt.NDArray[np.intp]:
14721508
n_total = product(chunk_shape)
1473-
order: list[tuple[int, ...]] = []
1474-
i = 0
1475-
while len(order) < n_total:
1509+
n_dims = len(chunk_shape)
1510+
if n_total == 0:
1511+
out = np.empty((0, n_dims), dtype=np.intp)
1512+
out.flags.writeable = False
1513+
return out
1514+
1515+
# Optimization: Remove singleton dimensions to enable magic number usage
1516+
# for shapes like (1,1,32,32,32). Compute Morton on squeezed shape, then expand.
1517+
singleton_dims = tuple(i for i, s in enumerate(chunk_shape) if s == 1)
1518+
if singleton_dims:
1519+
squeezed_shape = tuple(s for s in chunk_shape if s != 1)
1520+
if squeezed_shape:
1521+
# Compute Morton order on squeezed shape, then expand singleton dims (always 0)
1522+
squeezed_order = np.asarray(_morton_order(squeezed_shape))
1523+
out = np.zeros((n_total, n_dims), dtype=np.intp)
1524+
squeezed_col = 0
1525+
for full_col in range(n_dims):
1526+
if chunk_shape[full_col] != 1:
1527+
out[:, full_col] = squeezed_order[:, squeezed_col]
1528+
squeezed_col += 1
1529+
else:
1530+
# All dimensions are singletons, just return the single point
1531+
out = np.zeros((1, n_dims), dtype=np.intp)
1532+
out.flags.writeable = False
1533+
return out
1534+
1535+
# Find the largest power-of-2 hypercube that fits within chunk_shape.
1536+
# Within this hypercube, Morton codes are guaranteed to be in bounds.
1537+
min_dim = min(chunk_shape)
1538+
if min_dim >= 1:
1539+
power = min_dim.bit_length() - 1 # floor(log2(min_dim))
1540+
hypercube_size = 1 << power # 2^power
1541+
n_hypercube = hypercube_size**n_dims
1542+
else:
1543+
n_hypercube = 0
1544+
1545+
# Within the hypercube, no bounds checking needed - use vectorized decoding
1546+
if n_hypercube > 0:
1547+
z_values = np.arange(n_hypercube, dtype=np.intp)
1548+
order: npt.NDArray[np.intp] = decode_morton_vectorized(z_values, chunk_shape)
1549+
else:
1550+
order = np.empty((0, n_dims), dtype=np.intp)
1551+
1552+
# For remaining elements outside the hypercube, bounds checking is needed
1553+
remaining: list[tuple[int, ...]] = []
1554+
i = n_hypercube
1555+
while len(order) + len(remaining) < n_total:
14761556
m = decode_morton(i, chunk_shape)
14771557
if all(x < y for x, y in zip(m, chunk_shape, strict=False)):
1478-
order.append(m)
1558+
remaining.append(m)
14791559
i += 1
1480-
return tuple(order)
1560+
1561+
if remaining:
1562+
order = np.vstack([order, np.array(remaining, dtype=np.intp)])
1563+
order.flags.writeable = False
1564+
return order
1565+
1566+
1567+
@lru_cache(maxsize=16)
1568+
def _morton_order_keys(chunk_shape: tuple[int, ...]) -> tuple[tuple[int, ...], ...]:
1569+
return tuple(tuple(int(x) for x in row) for row in _morton_order(chunk_shape))
14811570

14821571

14831572
def morton_order_iter(chunk_shape: tuple[int, ...]) -> Iterator[tuple[int, ...]]:
1484-
return iter(_morton_order(tuple(chunk_shape)))
1573+
return iter(_morton_order_keys(tuple(chunk_shape)))
14851574

14861575

14871576
def c_order_iter(chunks_per_shard: tuple[int, ...]) -> Iterator[tuple[int, ...]]:

src/zarr/testing/stateful.py

Lines changed: 6 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -340,13 +340,13 @@ def delete_array_using_del(self, data: DataObject) -> None:
340340
self.all_arrays.remove(array_path)
341341

342342
@precondition(lambda self: self.store.supports_deletes)
343-
@precondition(lambda self: len(self.all_groups) >= 2) # fixme don't delete root
343+
@precondition(lambda self: bool(self.all_groups))
344344
@rule(data=st.data())
345345
def delete_group_using_del(self, data: DataObject) -> None:
346-
# ensure that we don't include the root group in the list of member names that we try
347-
# to delete
348-
member_names = tuple(filter(lambda v: "/" in v, sorted(self.all_groups)))
349-
group_path = data.draw(st.sampled_from(member_names), label="Group deletion target")
346+
group_path = data.draw(
347+
st.sampled_from(sorted(self.all_groups)),
348+
label="Group deletion target",
349+
)
350350
prefix, group_name = split_prefix_name(group_path)
351351
note(f"Deleting group '{group_path=!r}', {prefix=!r}, {group_name=!r} using delete")
352352
members = zarr.open_group(store=self.model, path=group_path).members(max_depth=None)
@@ -359,9 +359,7 @@ def delete_group_using_del(self, data: DataObject) -> None:
359359
group = zarr.open_group(store=store, path=prefix)
360360
group[group_name] # check that it exists
361361
del group[group_name]
362-
if group_path != "/":
363-
# The root group is always present
364-
self.all_groups.remove(group_path)
362+
self.all_groups.remove(group_path)
365363

366364
# # --------------- assertions -----------------
367365
# def check_group_arrays(self, group):

0 commit comments

Comments (0)