|
46 | 46 | from zarr.core.indexing import ( |
47 | 47 | BasicIndexer, |
48 | 48 | SelectorTuple, |
| 49 | + _morton_order, |
49 | 50 | c_order_iter, |
50 | 51 | get_indexer, |
51 | 52 | morton_order_iter, |
@@ -138,6 +139,45 @@ def get_chunk_slice(self, chunk_coords: tuple[int, ...]) -> tuple[int, int] | No |
138 | 139 | else: |
139 | 140 | return (int(chunk_start), int(chunk_start + chunk_len)) |
140 | 141 |
|
| 142 | + def get_chunk_slices_vectorized( |
| 143 | + self, chunk_coords_array: npt.NDArray[np.uint64] |
| 144 | + ) -> tuple[npt.NDArray[np.uint64], npt.NDArray[np.uint64], npt.NDArray[np.bool_]]: |
| 145 | + """Get chunk slices for multiple coordinates at once. |
| 146 | +
|
| 147 | + Parameters |
| 148 | + ---------- |
| 149 | + chunk_coords_array : ndarray of shape (n_chunks, n_dims) |
| 150 | + Array of chunk coordinates to look up. |
| 151 | +
|
| 152 | + Returns |
| 153 | + ------- |
| 154 | + starts : ndarray of shape (n_chunks,) |
| 155 | + Start byte positions for each chunk. |
| 156 | + ends : ndarray of shape (n_chunks,) |
| 157 | + End byte positions for each chunk. |
| 158 | + valid : ndarray of shape (n_chunks,) |
| 159 | + Boolean mask indicating which chunks are non-empty. |
| 160 | + """ |
| 161 | + # Localize coordinates via modulo (vectorized) |
| 162 | + shard_shape = np.array(self.offsets_and_lengths.shape[:-1], dtype=np.uint64) |
| 163 | + localized = chunk_coords_array % shard_shape |
| 164 | + |
| 165 | + # Build index tuple for advanced indexing |
| 166 | + index_tuple = tuple(localized[:, i] for i in range(localized.shape[1])) |
| 167 | + |
| 168 | + # Fetch all offsets and lengths at once |
| 169 | + offsets_and_lengths = self.offsets_and_lengths[index_tuple] |
| 170 | + starts = offsets_and_lengths[:, 0] |
| 171 | + lengths = offsets_and_lengths[:, 1] |
| 172 | + |
| 173 | + # Check for valid (non-empty) chunks |
| 174 | + valid = starts != MAX_UINT_64 |
| 175 | + |
| 176 | + # Compute end positions |
| 177 | + ends = starts + lengths |
| 178 | + |
| 179 | + return starts, ends, valid |
| 180 | + |
141 | 181 | def set_chunk_slice(self, chunk_coords: tuple[int, ...], chunk_slice: slice | None) -> None: |
142 | 182 | localized_chunk = self._localize_chunk(chunk_coords) |
143 | 183 | if chunk_slice is None: |
@@ -219,6 +259,35 @@ def __len__(self) -> int: |
219 | 259 | def __iter__(self) -> Iterator[tuple[int, ...]]: |
220 | 260 | return c_order_iter(self.index.offsets_and_lengths.shape[:-1]) |
221 | 261 |
|
| 262 | + def to_dict_vectorized( |
| 263 | + self, |
| 264 | + chunk_coords_array: npt.NDArray[np.uint64], |
| 265 | + chunk_coords_tuples: tuple[tuple[int, ...], ...], |
| 266 | + ) -> dict[tuple[int, ...], Buffer | None]: |
| 267 | + """Build a dict of chunk coordinates to buffers using vectorized lookup. |
| 268 | +
|
| 269 | + Parameters |
| 270 | + ---------- |
| 271 | + chunk_coords_array : ndarray of shape (n_chunks, n_dims) |
| 272 | + Array of chunk coordinates for vectorized index lookup. |
| 273 | + chunk_coords_tuples : tuple of tuples |
| 274 | + The same coordinates as tuples, used as dict keys to avoid conversion. |
| 275 | +
|
| 276 | + Returns |
| 277 | + ------- |
| 278 | + dict mapping chunk coordinate tuples to Buffer or None |
| 279 | + """ |
| 280 | + starts, ends, valid = self.index.get_chunk_slices_vectorized(chunk_coords_array) |
| 281 | + |
| 282 | + result: dict[tuple[int, ...], Buffer | None] = {} |
| 283 | + for i, coords in enumerate(chunk_coords_tuples): |
| 284 | + if valid[i]: |
| 285 | + result[coords] = self.buf[int(starts[i]) : int(ends[i])] |
| 286 | + else: |
| 287 | + result[coords] = None |
| 288 | + |
| 289 | + return result |
| 290 | + |
222 | 291 |
|
223 | 292 | @dataclass(frozen=True) |
224 | 293 | class ShardingCodec( |
@@ -505,7 +574,10 @@ async def _encode_partial_single( |
505 | 574 | chunks_per_shard=chunks_per_shard, |
506 | 575 | ) |
507 | 576 | shard_reader = shard_reader or _ShardReader.create_empty(chunks_per_shard) |
508 | | - shard_dict = {k: shard_reader.get(k) for k in morton_order_iter(chunks_per_shard)} |
| 577 | + # Use vectorized lookup for better performance |
| 578 | + morton_coords = _morton_order(chunks_per_shard) |
| 579 | + chunk_coords_array = np.array(morton_coords, dtype=np.uint64) |
| 580 | + shard_dict = shard_reader.to_dict_vectorized(chunk_coords_array, morton_coords) |
509 | 581 |
|
510 | 582 | indexer = list( |
511 | 583 | get_indexer( |
|
0 commit comments