|
1 | 1 | from types import EllipsisType |
2 | | -from typing import TYPE_CHECKING, Any, TypeAlias, cast |
| 2 | +from typing import TYPE_CHECKING, Any, TypeAlias, TypeGuard, cast |
3 | 3 |
|
4 | 4 | import numpy as np |
| 5 | +from zarr.codecs import BytesCodec, TransposeCodec |
| 6 | +from zarr.core.metadata.v3 import ArrayV3Metadata |
5 | 7 |
|
6 | 8 | from virtualizarr.manifests.array_api import expand_dims |
7 | 9 | from virtualizarr.manifests.manifest import ChunkManifest |
@@ -172,33 +174,62 @@ def apply_selection( |
172 | 174 | raise TypeError(f"Invalid indexer type: {indexer_1d}") |
173 | 175 | narrowed_indexers.append(indexer_1d) |
174 | 176 |
|
| 177 | + sub_chunk_axis = _uncompressed_sub_chunk_axis(marr.metadata) |
| 178 | + |
175 | 179 | new_shape: list[int] = [] |
176 | 180 | new_chunks: list[int] = [] |
177 | 181 | chunk_grid_selectors: list[int | slice] = [] |
178 | 182 | kept_axes: list[int] = [] |
| 183 | + # At most one sub-chunk axis (whichever axis has the largest byte stride in storage). |
| 184 | + # The byte adjustment is uniform across every surviving chunk, since chunks share layout. |
| 185 | + sub_chunk_byte_adjust: tuple[int, int] | None = None |
179 | 186 | for axis, (axis_length, chunk_size, indexer_1d) in enumerate( |
180 | 187 | zip(marr.shape, marr.chunks, narrowed_indexers, strict=True) |
181 | 188 | ): |
182 | | - chunk_grid_selector, new_axis_length = _compute_chunk_aligned_selection_1d( |
183 | | - indexer_1d, axis_length=axis_length, chunk_size=chunk_size |
184 | | - ) |
| 189 | + chunk_grid_selector: int | slice |
| 190 | + if axis == sub_chunk_axis and _is_sub_chunk_slice( |
| 191 | + indexer_1d, axis_length, chunk_size |
| 192 | + ): |
| 193 | + chunk_grid_selector, new_axis_length, sub_chunk_byte_adjust = ( |
| 194 | + _compute_sub_chunk_axis_selection( |
| 195 | + indexer_1d, |
| 196 | + axis_length=axis_length, |
| 197 | + chunk_size=chunk_size, |
| 198 | + other_axis_chunks=tuple( |
| 199 | + c for i, c in enumerate(marr.chunks) if i != axis |
| 200 | + ), |
| 201 | + itemsize=marr.dtype.itemsize, |
| 202 | + ) |
| 203 | + ) |
| 204 | + new_chunks_for_axis = new_axis_length |
| 205 | + else: |
| 206 | + chunk_grid_selector, new_axis_length = _compute_chunk_aligned_selection_1d( |
| 207 | + indexer_1d, axis_length=axis_length, chunk_size=chunk_size |
| 208 | + ) |
| 209 | + new_chunks_for_axis = chunk_size |
| 210 | + |
185 | 211 | chunk_grid_selectors.append(chunk_grid_selector) |
186 | | - # int selectors drop the axis from the output array |
| 212 | + # int indexers drop the axis from the output array; slices preserve it (including the |
| 213 | + # sub-chunk path, which uses a length-1 chunk-grid slice selector). |
187 | 214 | if not isinstance(indexer_1d, int): |
188 | 215 | new_shape.append(new_axis_length) |
189 | | - new_chunks.append(chunk_size) |
| 216 | + new_chunks.append(new_chunks_for_axis) |
190 | 217 | kept_axes.append(axis) |
191 | 218 |
|
192 | 219 | chunk_grid_selectors_tuple = tuple(chunk_grid_selectors) |
193 | 220 |
|
194 | | - # short-circuit if every axis selects the whole chunk grid via a slice (a no-op) |
195 | | - if all( |
| 221 | + # short-circuit if every axis selects the whole chunk grid via a slice (a no-op). |
| 222 | + # A pending sub-chunk byte adjustment is real work even if its single source chunk |
| 223 | + # happens to span the whole chunk grid along that axis, so don't short-circuit then. |
| 224 | + if sub_chunk_byte_adjust is None and all( |
196 | 225 | isinstance(cgs, slice) and cgs == slice(0, dim, 1) |
197 | 226 | for cgs, dim in zip(chunk_grid_selectors_tuple, marr.manifest.shape_chunk_grid) |
198 | 227 | ): |
199 | 228 | return marr |
200 | 229 |
|
201 | 230 | new_manifest = _subset_manifest(marr.manifest, chunk_grid_selectors_tuple) |
| 231 | + if sub_chunk_byte_adjust is not None: |
| 232 | + new_manifest = _shift_manifest_byte_ranges(new_manifest, *sub_chunk_byte_adjust) |
202 | 233 | old_dimension_names = marr.metadata.dimension_names |
203 | 234 | # zarr's dimension_names is tuple[str | None, ...] but copy_and_replace_metadata's |
204 | 235 | # type hint says Iterable[str]; the runtime handles None entries fine, so cast through. |
@@ -265,6 +296,38 @@ def _compute_chunk_aligned_selection_1d( |
265 | 296 | return slice(chunk_start, chunk_stop, 1), stop - start |
266 | 297 |
|
267 | 298 |
|
| 299 | +def _compute_sub_chunk_axis_selection( |
| 300 | + indexer_1d: slice, |
| 301 | + axis_length: int, |
| 302 | + chunk_size: int, |
| 303 | + other_axis_chunks: tuple[int, ...], |
| 304 | + itemsize: int, |
| 305 | +) -> tuple[slice, int, tuple[int, int]]: |
| 306 | + """ |
| 307 | + Translate a sub-chunk slice along the eligible (largest-stride) storage axis into a |
| 308 | + chunk-grid selector, an output axis length, and a uniform byte adjustment |
| 309 | + ``(offset_delta, new_chunk_byte_length)`` applied to every surviving chunk reference. |
| 310 | +
|
| 311 | + Callers must have already confirmed that this slice is sub-chunk-eligible via |
| 312 | + ``_is_sub_chunk_slice`` and that the array is uncompressed via |
| 313 | + ``_uncompressed_sub_chunk_axis``. |
| 314 | + """ |
| 315 | + start, stop, _ = indexer_1d.indices(axis_length) |
| 316 | + chunk_index = start // chunk_size |
| 317 | + new_axis_length = stop - start |
| 318 | + # Bytes per index step along this axis within one chunk is the product of every |
| 319 | + # *other* axis's chunk size, times itemsize. Order doesn't matter since the product |
| 320 | + # is commutative. |
| 321 | + stride_bytes = int(np.prod(other_axis_chunks)) * itemsize |
| 322 | + inner_offset_bytes = (start - chunk_index * chunk_size) * stride_bytes |
| 323 | + sub_chunk_byte_adjust = (inner_offset_bytes, new_axis_length * stride_bytes) |
| 324 | + return ( |
| 325 | + slice(chunk_index, chunk_index + 1, 1), |
| 326 | + new_axis_length, |
| 327 | + sub_chunk_byte_adjust, |
| 328 | + ) |
| 329 | + |
| 330 | + |
268 | 331 | def _subset_manifest( |
269 | 332 | manifest: ChunkManifest, chunk_grid_selectors: tuple[int | slice, ...] |
270 | 333 | ) -> ChunkManifest: |
@@ -323,3 +386,86 @@ def _subset_manifest( |
323 | 386 | inlined=new_inlined, |
324 | 387 | validate_paths=False, |
325 | 388 | ) |
| 389 | + |
| 390 | + |
| 391 | +def _uncompressed_sub_chunk_axis(metadata: ArrayV3Metadata) -> int | None: |
| 392 | + """ |
| 393 | + Return the axis along which sub-chunk slicing is implementable for this array, or |
| 394 | + ``None`` if the codec stack disqualifies it. |
| 395 | +
|
| 396 | + Sub-chunk slicing rewrites an existing chunk reference's byte offset and length, |
| 397 | + so it only works when chunk bytes are raw element values in a fixed memory order — |
| 398 | + i.e., no compression, no value transforms, no checksums. The eligible codec stacks |
| 399 | + are: |
| 400 | +
|
| 401 | + - ``[BytesCodec]`` — C-order layout; the axis with the largest byte stride is axis 0. |
| 402 | + - ``[TransposeCodec(order=perm), BytesCodec]`` — stored layout is the logical array |
| 403 | + permuted by ``perm``; the axis with the largest byte stride in storage is logical |
| 404 | + axis ``perm[0]``. For the F-order case ``perm = (n-1, n-2, ..., 0)`` this picks out |
| 405 | + the last axis. |
| 406 | + """ |
| 407 | + codecs = metadata.codecs |
| 408 | + if len(codecs) == 1 and isinstance(codecs[0], BytesCodec): |
| 409 | + return 0 |
| 410 | + if ( |
| 411 | + len(codecs) == 2 |
| 412 | + and isinstance(codecs[0], TransposeCodec) |
| 413 | + and isinstance(codecs[1], BytesCodec) |
| 414 | + ): |
| 415 | + return int(codecs[0].order[0]) |
| 416 | + return None |
| 417 | + |
| 418 | + |
| 419 | +def _is_sub_chunk_slice( |
| 420 | + indexer_1d: int | slice, axis_length: int, chunk_size: int |
| 421 | +) -> TypeGuard[slice]: |
| 422 | + """ |
| 423 | + True iff this is a slice that should take the sub-chunk path: step == 1, non-empty, |
| 424 | + fits entirely within one source chunk, and is NOT already chunk-aligned (chunk-aligned |
| 425 | + slices go through the simpler aligned path). |
| 426 | +
|
| 427 | + Typed as ``TypeGuard[slice]`` so callers can pass the narrowed indexer straight into |
| 428 | + helpers that take a ``slice``. |
| 429 | + """ |
| 430 | + if not isinstance(indexer_1d, slice): |
| 431 | + return False |
| 432 | + start, stop, step = indexer_1d.indices(axis_length) |
| 433 | + if step != 1 or start >= stop: |
| 434 | + return False |
| 435 | + # chunk-aligned slices are handled by _compute_chunk_aligned_selection_1d |
| 436 | + aligned = start % chunk_size == 0 and ( |
| 437 | + stop == axis_length or stop % chunk_size == 0 |
| 438 | + ) |
| 439 | + if aligned: |
| 440 | + return False |
| 441 | + # contained in a single source chunk? |
| 442 | + return start // chunk_size == (stop - 1) // chunk_size |
| 443 | + |
| 444 | + |
| 445 | +def _shift_manifest_byte_ranges( |
| 446 | + manifest: ChunkManifest, offset_delta: int, new_length: int |
| 447 | +) -> ChunkManifest: |
| 448 | + """ |
| 449 | + Return a new ``ChunkManifest`` whose virtual chunk references point to a uniform |
| 450 | + sub-range of each original chunk: ``offset += offset_delta`` and ``length = new_length``. |
| 451 | +
|
| 452 | + Used by the uncompressed-axis-0 sub-chunk path, where every surviving chunk shares the |
| 453 | + same byte layout and therefore the same byte adjustment. |
| 454 | + """ |
| 455 | + new_offsets = cast( |
| 456 | + "np.ndarray[Any, np.dtype[np.uint64]]", |
| 457 | + manifest._offsets + np.uint64(offset_delta), |
| 458 | + ) |
| 459 | + new_lengths = cast( |
| 460 | + "np.ndarray[Any, np.dtype[np.uint64]]", |
| 461 | + np.full_like(manifest._lengths, np.uint64(new_length)), |
| 462 | + ) |
| 463 | + # paths and any inlined-chunk dict carry through unchanged: inlined chunks aren't |
| 464 | + # involved here (this path is only taken for uncompressed virtual references). |
| 465 | + return ChunkManifest.from_arrays( |
| 466 | + paths=manifest._paths, |
| 467 | + offsets=new_offsets, |
| 468 | + lengths=new_lengths, |
| 469 | + inlined=dict(manifest._inlined), |
| 470 | + validate_paths=False, |
| 471 | + ) |
0 commit comments