Commit 2f9b0b3

mkitticlauded-v-b authored
perf: Vectorize get_chunk_slice for faster sharded writes (#3713)
* perf: Skip bounds check for initial elements in 2^n hypercube
* lint: Use a list comprehension rather than a for loop
* perf: Add decode_morton_vectorized
* perf: Replace math.log2() with bit_length()
* perf: Use magic numbers for 2D and 3D
* perf: Add 4D Morton magic numbers
* perf: Add Morton magic numbers for 5D
* perf: Remove singleton dimensions to reduce ndims
* Add changes
* fix: Address type annotation and linting issues
* perf: Remove magic number functions
* test: Add power of 2 sharding indexing tests
* test: Add Morton order benchmarks with cache clearing

  Add benchmarks that clear the _morton_order LRU cache before each iteration
  to measure the full Morton computation cost:
  - test_sharded_morton_indexing: 512-4096 chunks per shard
  - test_sharded_morton_indexing_large: 32768 chunks per shard

  Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* fix: Bound LRU cache of _morton_order to 16
* test: Add a single chunk test for a large shard
* test: Add indexing benchmarks for writing
* tests: Add single chunk write test for sharding
* perf: Vectorize get_chunk_slice for faster sharded writes

  Add vectorized methods to _ShardIndex and _ShardReader for batch chunk slice
  lookups, reducing per-chunk function call overhead when writing to shards.

  Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* refactor: Return ndarray from _morton_order, simplify to_dict_vectorized

  _morton_order now returns a read-only npt.NDArray[np.intp] (annotated as
  Iterable[Sequence[int]]) instead of a tuple of tuples, eliminating the
  intermediate list-of-tuples allocation. morton_order_iter converts rows to
  tuples on the fly. to_dict_vectorized no longer requires a redundant
  chunk_coords_tuples argument; tuple conversion happens inline during dict
  population. get_chunk_slices_vectorized accepts any integer array dtype
  (npt.NDArray[np.integer[Any]]) and casts to uint64 internally.

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* perf: Cache tuple keys separately from ndarray in _morton_order_keys

  Add _morton_order_keys() as a second lru_cache that converts the ndarray
  returned by _morton_order into a tuple of tuples. This restores cached
  access to hashable chunk coordinate keys without reverting to the old
  dual-argument interface.

  morton_order_iter now uses _morton_order_keys, and to_dict_vectorized
  derives its keys from _morton_order_keys internally using the shard index
  shape, keeping the call site single-argument.

  Result: test_sharded_morton_write_single_chunk[(32,32,32)] improves from
  ~33ms to ~7ms (a ~5x speedup relative to before this PR's changes).

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* tests: Clear _morton_order_keys cache alongside _morton_order in benchmarks

  All benchmark functions that call _morton_order.cache_clear() now also call
  _morton_order_keys.cache_clear() to ensure both caches are reset before each
  benchmark iteration.

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* refactor: Use npt.NDArray[np.intp] as return type for _morton_order

  More precise than Iterable[Sequence[int]] and accurately reflects the
  actual return value. Remove the now-unused Iterable import.

  Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: Davis Bennett <davis.v.bennett@gmail.com>
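For context on the Morton (Z-order) indexing this commit optimizes: a Morton code de-interleaves the bits of a single integer round-robin across dimensions, so consecutive codes visit chunk coordinates in a locality-preserving zig-zag. The following is a minimal scalar sketch of that idea, not the exact zarr implementation (the `decode_morton` name mirrors the diff, but this body is illustrative):

```python
def decode_morton(z: int, chunk_shape: tuple[int, ...]) -> tuple[int, ...]:
    """De-interleave the bits of z round-robin across dimensions."""
    # Bits needed per dimension: ceil(log2(size)), computed via bit_length()
    bits = [max(size - 1, 0).bit_length() for size in chunk_shape]
    coords = [0] * len(chunk_shape)
    in_bit = 0
    for coord_bit in range(max(bits, default=0)):
        for dim, n_bits in enumerate(bits):
            # Skip dimensions whose bit budget is already exhausted
            if coord_bit < n_bits:
                coords[dim] |= ((z >> in_bit) & 1) << coord_bit
                in_bit += 1
    return tuple(coords)

# For a (2, 2) grid, z = 0..3 visits (0,0), (1,0), (0,1), (1,1)
order = [decode_morton(z, (2, 2)) for z in range(4)]
```

For power-of-two shapes the mapping is a bijection, which is why the PR can skip bounds checks inside the largest 2^n hypercube and only bounds-check the remainder.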
1 parent f8b3d38 commit 2f9b0b3

File tree

4 files changed: +114, -35 lines


changes/3713.misc.md

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+Vectorize get_chunk_slice for faster sharded array writes.

src/zarr/codecs/sharding.py

Lines changed: 71 additions & 1 deletion

@@ -46,6 +46,8 @@
 from zarr.core.indexing import (
     BasicIndexer,
     SelectorTuple,
+    _morton_order,
+    _morton_order_keys,
     c_order_iter,
     get_indexer,
     morton_order_iter,
@@ -144,6 +146,45 @@ def get_chunk_slice(self, chunk_coords: tuple[int, ...]) -> tuple[int, int] | None
         else:
             return (int(chunk_start), int(chunk_start + chunk_len))

+    def get_chunk_slices_vectorized(
+        self, chunk_coords_array: npt.NDArray[np.integer[Any]]
+    ) -> tuple[npt.NDArray[np.uint64], npt.NDArray[np.uint64], npt.NDArray[np.bool_]]:
+        """Get chunk slices for multiple coordinates at once.
+
+        Parameters
+        ----------
+        chunk_coords_array : ndarray of shape (n_chunks, n_dims)
+            Array of chunk coordinates to look up.
+
+        Returns
+        -------
+        starts : ndarray of shape (n_chunks,)
+            Start byte positions for each chunk.
+        ends : ndarray of shape (n_chunks,)
+            End byte positions for each chunk.
+        valid : ndarray of shape (n_chunks,)
+            Boolean mask indicating which chunks are non-empty.
+        """
+        # Localize coordinates via modulo (vectorized)
+        shard_shape = np.array(self.offsets_and_lengths.shape[:-1], dtype=np.uint64)
+        localized = chunk_coords_array.astype(np.uint64) % shard_shape
+
+        # Build index tuple for advanced indexing
+        index_tuple = tuple(localized[:, i] for i in range(localized.shape[1]))
+
+        # Fetch all offsets and lengths at once
+        offsets_and_lengths = self.offsets_and_lengths[index_tuple]
+        starts = offsets_and_lengths[:, 0]
+        lengths = offsets_and_lengths[:, 1]
+
+        # Check for valid (non-empty) chunks
+        valid = starts != MAX_UINT_64
+
+        # Compute end positions
+        ends = starts + lengths
+
+        return starts, ends, valid
+
     def set_chunk_slice(self, chunk_coords: tuple[int, ...], chunk_slice: slice | None) -> None:
         localized_chunk = self._localize_chunk(chunk_coords)
         if chunk_slice is None:
@@ -225,6 +266,34 @@ def __len__(self) -> int:
     def __iter__(self) -> Iterator[tuple[int, ...]]:
         return c_order_iter(self.index.offsets_and_lengths.shape[:-1])

+    def to_dict_vectorized(
+        self,
+        chunk_coords_array: npt.NDArray[np.integer[Any]],
+    ) -> dict[tuple[int, ...], Buffer | None]:
+        """Build a dict of chunk coordinates to buffers using vectorized lookup.
+
+        Parameters
+        ----------
+        chunk_coords_array : ndarray of shape (n_chunks, n_dims)
+            Array of chunk coordinates for vectorized index lookup.
+
+        Returns
+        -------
+        dict mapping chunk coordinate tuples to Buffer or None
+        """
+        starts, ends, valid = self.index.get_chunk_slices_vectorized(chunk_coords_array)
+        chunks_per_shard = tuple(self.index.offsets_and_lengths.shape[:-1])
+        chunk_coords_keys = _morton_order_keys(chunks_per_shard)
+
+        result: dict[tuple[int, ...], Buffer | None] = {}
+        for i, coords in enumerate(chunk_coords_keys):
+            if valid[i]:
+                result[coords] = self.buf[int(starts[i]) : int(ends[i])]
+            else:
+                result[coords] = None
+
+        return result


 @dataclass(frozen=True)
 class ShardingCodec(
@@ -511,7 +580,8 @@ async def _encode_partial_single(
             chunks_per_shard=chunks_per_shard,
         )
         shard_reader = shard_reader or _ShardReader.create_empty(chunks_per_shard)
-        shard_dict = {k: shard_reader.get(k) for k in morton_order_iter(chunks_per_shard)}
+        # Use vectorized lookup for better performance
+        shard_dict = shard_reader.to_dict_vectorized(np.asarray(_morton_order(chunks_per_shard)))

         indexer = list(
             get_indexer(
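The core trick in get_chunk_slices_vectorized — gathering all (offset, length) pairs with a single NumPy advanced-indexing call and masking absent chunks via the MAX_UINT_64 sentinel — can be sketched standalone. The toy index below mirrors the layout described in the diff (`offsets_and_lengths` with a trailing axis of size 2) but is not the zarr code itself:

```python
import numpy as np

MAX_UINT_64 = np.uint64(2**64 - 1)  # sentinel: chunk not present in shard

# Toy shard index for a 2x2 grid of chunks; last axis holds (offset, length)
offsets_and_lengths = np.full((2, 2, 2), MAX_UINT_64, dtype=np.uint64)
offsets_and_lengths[0, 0] = (0, 100)   # chunk (0, 0): bytes [0, 100)
offsets_and_lengths[1, 1] = (100, 50)  # chunk (1, 1): bytes [100, 150)

coords = np.array([[0, 0], [0, 1], [1, 1]], dtype=np.uint64)
# One fancy-index array per dimension gathers every requested row at once,
# replacing a Python-level lookup per chunk with a single NumPy call
rows = offsets_and_lengths[tuple(coords[:, i] for i in range(coords.shape[1]))]
starts, lengths = rows[:, 0], rows[:, 1]
valid = starts != MAX_UINT_64  # chunk (0, 1) was never written
ends = starts + lengths        # wraps where ~valid; callers must apply the mask
```

This is why the batch lookup pays off for shards with thousands of chunks: the per-chunk cost drops from a function call plus tuple math to one vectorized gather.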

src/zarr/core/indexing.py

Lines changed: 32 additions & 29 deletions

@@ -1504,37 +1504,33 @@ def decode_morton_vectorized(


 @lru_cache(maxsize=16)
-def _morton_order(chunk_shape: tuple[int, ...]) -> tuple[tuple[int, ...], ...]:
+def _morton_order(chunk_shape: tuple[int, ...]) -> npt.NDArray[np.intp]:
     n_total = product(chunk_shape)
+    n_dims = len(chunk_shape)
     if n_total == 0:
-        return ()
+        out = np.empty((0, n_dims), dtype=np.intp)
+        out.flags.writeable = False
+        return out

     # Optimization: Remove singleton dimensions to enable magic number usage
     # for shapes like (1,1,32,32,32). Compute Morton on squeezed shape, then expand.
     singleton_dims = tuple(i for i, s in enumerate(chunk_shape) if s == 1)
     if singleton_dims:
         squeezed_shape = tuple(s for s in chunk_shape if s != 1)
         if squeezed_shape:
-            # Compute Morton order on squeezed shape
-            squeezed_order = _morton_order(squeezed_shape)
-            # Expand coordinates to include singleton dimensions (always 0)
-            expanded: list[tuple[int, ...]] = []
-            for coord in squeezed_order:
-                full_coord: list[int] = []
-                squeezed_idx = 0
-                for i in range(len(chunk_shape)):
-                    if chunk_shape[i] == 1:
-                        full_coord.append(0)
-                    else:
-                        full_coord.append(coord[squeezed_idx])
-                        squeezed_idx += 1
-                expanded.append(tuple(full_coord))
-            return tuple(expanded)
+            # Compute Morton order on squeezed shape, then expand singleton dims (always 0)
+            squeezed_order = np.asarray(_morton_order(squeezed_shape))
+            out = np.zeros((n_total, n_dims), dtype=np.intp)
+            squeezed_col = 0
+            for full_col in range(n_dims):
+                if chunk_shape[full_col] != 1:
+                    out[:, full_col] = squeezed_order[:, squeezed_col]
+                    squeezed_col += 1
         else:
             # All dimensions are singletons, just return the single point
-            return ((0,) * len(chunk_shape),)
-
-    n_dims = len(chunk_shape)
+            out = np.zeros((1, n_dims), dtype=np.intp)
+        out.flags.writeable = False
+        return out

     # Find the largest power-of-2 hypercube that fits within chunk_shape.
     # Within this hypercube, Morton codes are guaranteed to be in bounds.
@@ -1547,27 +1543,34 @@ def _morton_order(chunk_shape: tuple[int, ...]) -> tuple[tuple[int, ...], ...]:
         n_hypercube = 0

     # Within the hypercube, no bounds checking needed - use vectorized decoding
-    order: list[tuple[int, ...]]
     if n_hypercube > 0:
         z_values = np.arange(n_hypercube, dtype=np.intp)
-        hypercube_coords = decode_morton_vectorized(z_values, chunk_shape)
-        order = [tuple(row) for row in hypercube_coords]
+        order: npt.NDArray[np.intp] = decode_morton_vectorized(z_values, chunk_shape)
     else:
-        order = []
+        order = np.empty((0, n_dims), dtype=np.intp)

-    # For remaining elements, bounds checking is needed
+    # For remaining elements outside the hypercube, bounds checking is needed
+    remaining: list[tuple[int, ...]] = []
     i = n_hypercube
-    while len(order) < n_total:
+    while len(order) + len(remaining) < n_total:
         m = decode_morton(i, chunk_shape)
         if all(x < y for x, y in zip(m, chunk_shape, strict=False)):
-            order.append(m)
+            remaining.append(m)
         i += 1

-    return tuple(order)
+    if remaining:
+        order = np.vstack([order, np.array(remaining, dtype=np.intp)])
+    order.flags.writeable = False
+    return order
+
+
+@lru_cache(maxsize=16)
+def _morton_order_keys(chunk_shape: tuple[int, ...]) -> tuple[tuple[int, ...], ...]:
+    return tuple(tuple(int(x) for x in row) for row in _morton_order(chunk_shape))


 def morton_order_iter(chunk_shape: tuple[int, ...]) -> Iterator[tuple[int, ...]]:
-    return iter(_morton_order(tuple(chunk_shape)))
+    return iter(_morton_order_keys(tuple(chunk_shape)))


 def c_order_iter(chunks_per_shard: tuple[int, ...]) -> Iterator[tuple[int, ...]]:

tests/benchmarks/test_indexing.py

Lines changed: 10 additions & 5 deletions

@@ -74,7 +74,7 @@ def test_sharded_morton_indexing(
     The Morton order cache is cleared before each iteration to measure the
     full computation cost.
     """
-    from zarr.core.indexing import _morton_order
+    from zarr.core.indexing import _morton_order, _morton_order_keys

     # Create array where each shard contains many small chunks
     # e.g., shards=(32,32,32) with chunks=(2,2,2) means 16x16x16 = 4096 chunks per shard
@@ -98,6 +98,7 @@
     def read_with_cache_clear() -> None:
         _morton_order.cache_clear()
+        _morton_order_keys.cache_clear()
         getitem(data, indexer)

     benchmark(read_with_cache_clear)
@@ -122,7 +123,7 @@ def test_sharded_morton_indexing_large(
     the Morton order computation a more significant portion of total time.
     The Morton order cache is cleared before each iteration.
     """
-    from zarr.core.indexing import _morton_order
+    from zarr.core.indexing import _morton_order, _morton_order_keys

     # 1x1x1 chunks means chunks_per_shard equals shard shape
     shape = tuple(s * 2 for s in shards)  # 2 shards per dimension
@@ -145,6 +146,7 @@
     def read_with_cache_clear() -> None:
         _morton_order.cache_clear()
+        _morton_order_keys.cache_clear()
         getitem(data, indexer)

     benchmark(read_with_cache_clear)
@@ -164,7 +166,7 @@ def test_sharded_morton_single_chunk(
     computing the full Morton order, making the optimization impact clear.
     The Morton order cache is cleared before each iteration.
     """
-    from zarr.core.indexing import _morton_order
+    from zarr.core.indexing import _morton_order, _morton_order_keys

     # 1x1x1 chunks means chunks_per_shard equals shard shape
     shape = tuple(s * 2 for s in shards)  # 2 shards per dimension
@@ -187,6 +189,7 @@
     def read_with_cache_clear() -> None:
         _morton_order.cache_clear()
+        _morton_order_keys.cache_clear()
         getitem(data, indexer)

     benchmark(read_with_cache_clear)
@@ -211,10 +214,11 @@ def test_morton_order_iter(
     optimization impact without array read/write overhead.
     The cache is cleared before each iteration.
     """
-    from zarr.core.indexing import _morton_order, morton_order_iter
+    from zarr.core.indexing import _morton_order, _morton_order_keys, morton_order_iter

     def compute_morton_order() -> None:
         _morton_order.cache_clear()
+        _morton_order_keys.cache_clear()
         # Consume the iterator to force computation
         list(morton_order_iter(shape))

@@ -239,7 +243,7 @@ def test_sharded_morton_write_single_chunk(
     """
     import numpy as np

-    from zarr.core.indexing import _morton_order
+    from zarr.core.indexing import _morton_order, _morton_order_keys

     # 1x1x1 chunks means chunks_per_shard equals shard shape
     shape = tuple(s * 2 for s in shards)  # 2 shards per dimension
@@ -262,6 +266,7 @@
     def write_with_cache_clear() -> None:
         _morton_order.cache_clear()
+        _morton_order_keys.cache_clear()
         data[indexer] = write_data

     benchmark(write_with_cache_clear)
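The benchmark pattern repeated above — clear every relevant lru_cache inside the benchmarked callable so each iteration pays the cold-path cost — can be shown in isolation. Here `expensive` is a cheap stand-in for the Morton computation, and the surrounding pytest-benchmark fixture is omitted:

```python
from functools import lru_cache


@lru_cache(maxsize=16)
def expensive(n: int) -> int:
    # Stand-in for _morton_order: trivial here, costly in the real code
    return sum(i * i for i in range(n))


def bench_iteration() -> None:
    # Clear *every* cache the code under test consults; leaving a second
    # cache (like _morton_order_keys) warm would make later iterations
    # measure cache hits instead of the full computation
    expensive.cache_clear()
    expensive(1_000)


bench_iteration()
```

Note that `cache_clear()` also resets the hit/miss counters, so after each iteration `cache_info()` reports exactly one cold miss.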
