Skip to content

Commit b3f368f

Browse files
authored
Merge branch 'main' into feat/memory-store-registry
2 parents 8d490af + 879e1ce commit b3f368f

6 files changed

Lines changed: 153 additions & 74 deletions

File tree

changes/3713.misc.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Vectorize get_chunk_slice for faster sharded array writes.

changes/3717.misc.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add benchmarks for Morton order computation with non-power-of-2 and near-miss shard shapes, covering both pure computation and end-to-end read/write performance.

src/zarr/codecs/sharding.py

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@
4646
from zarr.core.indexing import (
4747
BasicIndexer,
4848
SelectorTuple,
49+
_morton_order,
50+
_morton_order_keys,
4951
c_order_iter,
5052
get_indexer,
5153
morton_order_iter,
@@ -144,6 +146,45 @@ def get_chunk_slice(self, chunk_coords: tuple[int, ...]) -> tuple[int, int] | No
144146
else:
145147
return (int(chunk_start), int(chunk_start + chunk_len))
146148

149+
def get_chunk_slices_vectorized(
150+
self, chunk_coords_array: npt.NDArray[np.integer[Any]]
151+
) -> tuple[npt.NDArray[np.uint64], npt.NDArray[np.uint64], npt.NDArray[np.bool_]]:
152+
"""Get chunk slices for multiple coordinates at once.
153+
154+
Parameters
155+
----------
156+
chunk_coords_array : ndarray of shape (n_chunks, n_dims)
157+
Array of chunk coordinates to look up.
158+
159+
Returns
160+
-------
161+
starts : ndarray of shape (n_chunks,)
162+
Start byte positions for each chunk.
163+
ends : ndarray of shape (n_chunks,)
164+
End byte positions for each chunk.
165+
valid : ndarray of shape (n_chunks,)
166+
Boolean mask indicating which chunks are non-empty.
167+
"""
168+
# Localize coordinates via modulo (vectorized)
169+
shard_shape = np.array(self.offsets_and_lengths.shape[:-1], dtype=np.uint64)
170+
localized = chunk_coords_array.astype(np.uint64) % shard_shape
171+
172+
# Build index tuple for advanced indexing
173+
index_tuple = tuple(localized[:, i] for i in range(localized.shape[1]))
174+
175+
# Fetch all offsets and lengths at once
176+
offsets_and_lengths = self.offsets_and_lengths[index_tuple]
177+
starts = offsets_and_lengths[:, 0]
178+
lengths = offsets_and_lengths[:, 1]
179+
180+
# Check for valid (non-empty) chunks
181+
valid = starts != MAX_UINT_64
182+
183+
# Compute end positions
184+
ends = starts + lengths
185+
186+
return starts, ends, valid
187+
147188
def set_chunk_slice(self, chunk_coords: tuple[int, ...], chunk_slice: slice | None) -> None:
148189
localized_chunk = self._localize_chunk(chunk_coords)
149190
if chunk_slice is None:
@@ -225,6 +266,34 @@ def __len__(self) -> int:
225266
def __iter__(self) -> Iterator[tuple[int, ...]]:
226267
return c_order_iter(self.index.offsets_and_lengths.shape[:-1])
227268

269+
def to_dict_vectorized(
    self,
    chunk_coords_array: npt.NDArray[np.integer[Any]],
) -> dict[tuple[int, ...], Buffer | None]:
    """Build a dict of chunk coordinates to buffers using vectorized lookup.

    Parameters
    ----------
    chunk_coords_array : ndarray of shape (n_chunks, n_dims)
        Array of chunk coordinates for vectorized index lookup. Row ``i`` is
        paired with the ``i``-th key of ``_morton_order_keys`` for this
        shard's chunk grid, so the rows are expected to be in that same
        order.

    Returns
    -------
    dict mapping chunk coordinate tuples to Buffer or None
        ``None`` marks a chunk that the index reports as empty.
    """
    starts, ends, valid = self.index.get_chunk_slices_vectorized(chunk_coords_array)
    keys = _morton_order_keys(tuple(self.index.offsets_and_lengths.shape[:-1]))

    # Convert each array to plain Python values once instead of converting
    # numpy scalars one element at a time inside the loop.
    begin = starts.tolist()
    stop = ends.tolist()
    present = valid.tolist()

    return {
        key: self.buf[begin[i] : stop[i]] if present[i] else None
        for i, key in enumerate(keys)
    }
296+
228297

229298
@dataclass(frozen=True)
230299
class ShardingCodec(
@@ -511,7 +580,8 @@ async def _encode_partial_single(
511580
chunks_per_shard=chunks_per_shard,
512581
)
513582
shard_reader = shard_reader or _ShardReader.create_empty(chunks_per_shard)
514-
shard_dict = {k: shard_reader.get(k) for k in morton_order_iter(chunks_per_shard)}
583+
# Use vectorized lookup for better performance
584+
shard_dict = shard_reader.to_dict_vectorized(np.asarray(_morton_order(chunks_per_shard)))
515585

516586
indexer = list(
517587
get_indexer(

src/zarr/core/indexing.py

Lines changed: 53 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1504,70 +1504,67 @@ def decode_morton_vectorized(
15041504

15051505

15061506
@lru_cache(maxsize=16)
def _morton_order(chunk_shape: tuple[int, ...]) -> npt.NDArray[np.intp]:
    """Return all coordinates of a grid of shape ``chunk_shape`` in Morton order.

    Parameters
    ----------
    chunk_shape : tuple of int
        Extent of the grid along each dimension.

    Returns
    -------
    ndarray of shape (n_total, n_dims), dtype intp
        One in-bounds coordinate per row. The array is marked read-only
        because the same cached object is handed to every caller.
    """
    n_dims = len(chunk_shape)
    n_total = product(chunk_shape)
    if n_total == 0:
        empty = np.empty((0, n_dims), dtype=np.intp)
        empty.flags.writeable = False
        return empty

    # (c - 1).bit_length() is the number of bits needed to index c values
    # (0 for a singleton dimension). Summed over all dimensions it gives the
    # size, 2**bits, of the smallest power-of-2 hypercube whose Morton codes
    # cover every valid coordinate.
    bits_per_dim = tuple((extent - 1).bit_length() for extent in chunk_shape)
    n_z = 1 << sum(bits_per_dim)

    order: npt.NDArray[np.intp]
    if n_z > 4 * n_total:
        # Argsort strategy, for shapes just above a power of 2 such as
        # (33, 33, 33) where the hypercube is much larger than the grid
        # (n_z=262144 vs n_total=35937): enumerate only the valid
        # coordinates, encode each one to its Morton code, and sort.
        axes = [np.arange(extent, dtype=np.intp) for extent in chunk_shape]
        mesh = np.meshgrid(*axes, indexing="ij")
        candidates = np.stack([plane.ravel() for plane in mesh], axis=1)

        # Interleave the coordinate bits: bit 0 of every dimension first,
        # then bit 1, and so on, skipping dimensions that have run out of
        # bits.
        z_codes = np.zeros(n_total, dtype=np.intp)
        out_bit = 0
        for coord_bit in range(max(bits_per_dim)):
            for dim, width in enumerate(bits_per_dim):
                if coord_bit >= width:
                    continue
                z_codes |= ((candidates[:, dim] >> coord_bit) & 1) << out_bit
                out_bit += 1

        ranking: npt.NDArray[np.intp] = np.argsort(z_codes, kind="stable")
        order = np.asarray(candidates[ranking], dtype=np.intp)
    else:
        # Ceiling strategy, used while the hypercube is at most 4x the grid:
        # decode every Morton code in the hypercube and keep only the rows
        # that fall inside chunk_shape.
        decoded = decode_morton_vectorized(np.arange(n_z, dtype=np.intp), chunk_shape)
        bounds = np.array(chunk_shape, dtype=np.intp)
        in_bounds = np.all(decoded < bounds, axis=1)
        order = decoded[in_bounds]

    order.flags.writeable = False
    return order
15571559

1558-
# For remaining elements, bounds checking is needed
1559-
i = n_hypercube
1560-
while len(order) < n_total:
1561-
m = decode_morton(i, chunk_shape)
1562-
if all(x < y for x, y in zip(m, chunk_shape, strict=False)):
1563-
order.append(m)
1564-
i += 1
15651560

1566-
return tuple(order)
1561+
@lru_cache(maxsize=16)
def _morton_order_keys(chunk_shape: tuple[int, ...]) -> tuple[tuple[int, ...], ...]:
    """Return the rows of ``_morton_order(chunk_shape)`` as tuples of Python ints.

    The result is cached separately from the array so the hashable tuple
    form is only built once per shape.
    """
    # ``tolist`` converts every element to a plain Python int in one pass.
    rows = _morton_order(chunk_shape).tolist()
    return tuple(map(tuple, rows))
15671564

15681565

15691566
def morton_order_iter(chunk_shape: tuple[int, ...]) -> Iterator[tuple[int, ...]]:
    """Iterate over the coordinate tuples of ``chunk_shape`` in Morton order."""
    # Normalize to a tuple before the lookup; the key function is wrapped in
    # ``lru_cache``, which needs a hashable argument.
    hashable_shape = tuple(chunk_shape)
    keys = _morton_order_keys(hashable_shape)
    return iter(keys)
15711568

15721569

15731570
def c_order_iter(chunks_per_shard: tuple[int, ...]) -> Iterator[tuple[int, ...]]:

src/zarr/testing/stateful.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -340,13 +340,13 @@ def delete_array_using_del(self, data: DataObject) -> None:
340340
self.all_arrays.remove(array_path)
341341

342342
@precondition(lambda self: self.store.supports_deletes)
343-
@precondition(lambda self: len(self.all_groups) >= 2) # fixme don't delete root
343+
@precondition(lambda self: bool(self.all_groups))
344344
@rule(data=st.data())
345345
def delete_group_using_del(self, data: DataObject) -> None:
346-
# ensure that we don't include the root group in the list of member names that we try
347-
# to delete
348-
member_names = tuple(filter(lambda v: "/" in v, sorted(self.all_groups)))
349-
group_path = data.draw(st.sampled_from(member_names), label="Group deletion target")
346+
group_path = data.draw(
347+
st.sampled_from(sorted(self.all_groups)),
348+
label="Group deletion target",
349+
)
350350
prefix, group_name = split_prefix_name(group_path)
351351
note(f"Deleting group '{group_path=!r}', {prefix=!r}, {group_name=!r} using delete")
352352
members = zarr.open_group(store=self.model, path=group_path).members(max_depth=None)
@@ -359,9 +359,7 @@ def delete_group_using_del(self, data: DataObject) -> None:
359359
group = zarr.open_group(store=store, path=prefix)
360360
group[group_name] # check that it exists
361361
del group[group_name]
362-
if group_path != "/":
363-
# The root group is always present
364-
self.all_groups.remove(group_path)
362+
self.all_groups.remove(group_path)
365363

366364
# # --------------- assertions -----------------
367365
# def check_group_arrays(self, group):

tests/benchmarks/test_indexing.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def test_sharded_morton_indexing(
7474
The Morton order cache is cleared before each iteration to measure the
7575
full computation cost.
7676
"""
77-
from zarr.core.indexing import _morton_order
77+
from zarr.core.indexing import _morton_order, _morton_order_keys
7878

7979
# Create array where each shard contains many small chunks
8080
# e.g., shards=(32,32,32) with chunks=(2,2,2) means 16x16x16 = 4096 chunks per shard
@@ -98,14 +98,18 @@ def test_sharded_morton_indexing(
9898

9999
def read_with_cache_clear() -> None:
100100
_morton_order.cache_clear()
101+
_morton_order_keys.cache_clear()
101102
getitem(data, indexer)
102103

103104
benchmark(read_with_cache_clear)
104105

105106

106107
# Benchmark with larger chunks_per_shard to make Morton order impact more visible
107108
# Shard shapes used with 1x1x1 chunks, so chunks per shard equals the shard shape.
large_morton_shards = (
    (32, 32, 32),  # 32768 chunks per shard (power-of-2)
    (30, 30, 30),  # 27000 chunks per shard (non-power-of-2)
    (33, 33, 33),  # 35937 chunks per shard (near-miss: just above power-of-2)
)
110114

111115

@@ -122,7 +126,7 @@ def test_sharded_morton_indexing_large(
122126
the Morton order computation a more significant portion of total time.
123127
The Morton order cache is cleared before each iteration.
124128
"""
125-
from zarr.core.indexing import _morton_order
129+
from zarr.core.indexing import _morton_order, _morton_order_keys
126130

127131
# 1x1x1 chunks means chunks_per_shard equals shard shape
128132
shape = tuple(s * 2 for s in shards) # 2 shards per dimension
@@ -145,6 +149,7 @@ def test_sharded_morton_indexing_large(
145149

146150
def read_with_cache_clear() -> None:
147151
_morton_order.cache_clear()
152+
_morton_order_keys.cache_clear()
148153
getitem(data, indexer)
149154

150155
benchmark(read_with_cache_clear)
@@ -164,7 +169,7 @@ def test_sharded_morton_single_chunk(
164169
computing the full Morton order, making the optimization impact clear.
165170
The Morton order cache is cleared before each iteration.
166171
"""
167-
from zarr.core.indexing import _morton_order
172+
from zarr.core.indexing import _morton_order, _morton_order_keys
168173

169174
# 1x1x1 chunks means chunks_per_shard equals shard shape
170175
shape = tuple(s * 2 for s in shards) # 2 shards per dimension
@@ -187,16 +192,21 @@ def test_sharded_morton_single_chunk(
187192

188193
def read_with_cache_clear() -> None:
189194
_morton_order.cache_clear()
195+
_morton_order_keys.cache_clear()
190196
getitem(data, indexer)
191197

192198
benchmark(read_with_cache_clear)
193199

194200

195201
# Benchmark for morton_order_iter directly (no I/O)
196202
# Cubic shapes for the pure Morton-order benchmark, listed by side length:
#    8 ->   512 elements (power-of-2)
#   10 ->  1000 elements (non-power-of-2)
#   16 ->  4096 elements (power-of-2)
#   20 ->  8000 elements (non-power-of-2)
#   32 -> 32768 elements (power-of-2)
#   30 -> 27000 elements (non-power-of-2)
#   33 -> 35937 elements (near-miss: just above power-of-2, n_z=262144)
morton_iter_shapes = tuple((side,) * 3 for side in (8, 10, 16, 20, 32, 30, 33))
201211

202212

@@ -211,10 +221,11 @@ def test_morton_order_iter(
211221
optimization impact without array read/write overhead.
212222
The cache is cleared before each iteration.
213223
"""
214-
from zarr.core.indexing import _morton_order, morton_order_iter
224+
from zarr.core.indexing import _morton_order, _morton_order_keys, morton_order_iter
215225

216226
def compute_morton_order() -> None:
217227
_morton_order.cache_clear()
228+
_morton_order_keys.cache_clear()
218229
# Consume the iterator to force computation
219230
list(morton_order_iter(shape))
220231

@@ -239,7 +250,7 @@ def test_sharded_morton_write_single_chunk(
239250
"""
240251
import numpy as np
241252

242-
from zarr.core.indexing import _morton_order
253+
from zarr.core.indexing import _morton_order, _morton_order_keys
243254

244255
# 1x1x1 chunks means chunks_per_shard equals shard shape
245256
shape = tuple(s * 2 for s in shards) # 2 shards per dimension
@@ -262,6 +273,7 @@ def test_sharded_morton_write_single_chunk(
262273

263274
def write_with_cache_clear() -> None:
264275
_morton_order.cache_clear()
276+
_morton_order_keys.cache_clear()
265277
data[indexer] = write_data
266278

267279
benchmark(write_with_cache_clear)

0 commit comments

Comments
 (0)