Skip to content

Commit 6129cd3

Browse files
mkitticlaude
andcommitted
perf: Vectorize get_chunk_slice for faster sharded writes
Add vectorized methods to _ShardIndex and _ShardReader for batch chunk slice lookups, reducing per-chunk function call overhead when writing to shards. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent a666211 commit 6129cd3

2 files changed

Lines changed: 74 additions & 1 deletion

File tree

changes/3713.misc.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Vectorize get_chunk_slice for faster sharded array writes.

src/zarr/codecs/sharding.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from zarr.core.indexing import (
4747
BasicIndexer,
4848
SelectorTuple,
49+
_morton_order,
4950
c_order_iter,
5051
get_indexer,
5152
morton_order_iter,
@@ -138,6 +139,45 @@ def get_chunk_slice(self, chunk_coords: tuple[int, ...]) -> tuple[int, int] | No
138139
else:
139140
return (int(chunk_start), int(chunk_start + chunk_len))
140141

142+
def get_chunk_slices_vectorized(
143+
self, chunk_coords_array: npt.NDArray[np.uint64]
144+
) -> tuple[npt.NDArray[np.uint64], npt.NDArray[np.uint64], npt.NDArray[np.bool_]]:
145+
"""Get chunk slices for multiple coordinates at once.
146+
147+
Parameters
148+
----------
149+
chunk_coords_array : ndarray of shape (n_chunks, n_dims)
150+
Array of chunk coordinates to look up.
151+
152+
Returns
153+
-------
154+
starts : ndarray of shape (n_chunks,)
155+
Start byte positions for each chunk.
156+
ends : ndarray of shape (n_chunks,)
157+
End byte positions for each chunk.
158+
valid : ndarray of shape (n_chunks,)
159+
Boolean mask indicating which chunks are non-empty.
160+
"""
161+
# Localize coordinates via modulo (vectorized)
162+
shard_shape = np.array(self.offsets_and_lengths.shape[:-1], dtype=np.uint64)
163+
localized = chunk_coords_array % shard_shape
164+
165+
# Build index tuple for advanced indexing
166+
index_tuple = tuple(localized[:, i] for i in range(localized.shape[1]))
167+
168+
# Fetch all offsets and lengths at once
169+
offsets_and_lengths = self.offsets_and_lengths[index_tuple]
170+
starts = offsets_and_lengths[:, 0]
171+
lengths = offsets_and_lengths[:, 1]
172+
173+
# Check for valid (non-empty) chunks
174+
valid = starts != MAX_UINT_64
175+
176+
# Compute end positions
177+
ends = starts + lengths
178+
179+
return starts, ends, valid
180+
141181
def set_chunk_slice(self, chunk_coords: tuple[int, ...], chunk_slice: slice | None) -> None:
142182
localized_chunk = self._localize_chunk(chunk_coords)
143183
if chunk_slice is None:
@@ -219,6 +259,35 @@ def __len__(self) -> int:
219259
def __iter__(self) -> Iterator[tuple[int, ...]]:
220260
return c_order_iter(self.index.offsets_and_lengths.shape[:-1])
221261

262+
def to_dict_vectorized(
263+
self,
264+
chunk_coords_array: npt.NDArray[np.uint64],
265+
chunk_coords_tuples: tuple[tuple[int, ...], ...],
266+
) -> dict[tuple[int, ...], Buffer | None]:
267+
"""Build a dict of chunk coordinates to buffers using vectorized lookup.
268+
269+
Parameters
270+
----------
271+
chunk_coords_array : ndarray of shape (n_chunks, n_dims)
272+
Array of chunk coordinates for vectorized index lookup.
273+
chunk_coords_tuples : tuple of tuples
274+
The same coordinates as tuples, used as dict keys to avoid conversion.
275+
276+
Returns
277+
-------
278+
dict mapping chunk coordinate tuples to Buffer or None
279+
"""
280+
starts, ends, valid = self.index.get_chunk_slices_vectorized(chunk_coords_array)
281+
282+
result: dict[tuple[int, ...], Buffer | None] = {}
283+
for i, coords in enumerate(chunk_coords_tuples):
284+
if valid[i]:
285+
result[coords] = self.buf[int(starts[i]) : int(ends[i])]
286+
else:
287+
result[coords] = None
288+
289+
return result
290+
222291

223292
@dataclass(frozen=True)
224293
class ShardingCodec(
@@ -505,7 +574,10 @@ async def _encode_partial_single(
505574
chunks_per_shard=chunks_per_shard,
506575
)
507576
shard_reader = shard_reader or _ShardReader.create_empty(chunks_per_shard)
508-
shard_dict = {k: shard_reader.get(k) for k in morton_order_iter(chunks_per_shard)}
577+
# Use vectorized lookup for better performance
578+
morton_coords = _morton_order(chunks_per_shard)
579+
chunk_coords_array = np.array(morton_coords, dtype=np.uint64)
580+
shard_dict = shard_reader.to_dict_vectorized(chunk_coords_array, morton_coords)
509581

510582
indexer = list(
511583
get_indexer(

0 commit comments

Comments
 (0)