Skip to content

Commit fce8cd7

Browse files
committed
For non-uniform chunks just pass them back as is
1 parent d7c9644 commit fce8cd7

File tree

1 file changed

+52
-36
lines changed

1 file changed

+52
-36
lines changed

xarray/namedarray/parallelcompat.py

Lines changed: 52 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -797,13 +797,13 @@ def preserve_chunks(
797797
"""Quickly determine optimal chunks close to target size but never splitting
798798
previous_chunks.
799799
800-
This takes in a chunks argument potentially containing ``"preserve"`` for all
801-
dimensions (if scalar) or several dimensions (if tuple). This function
802-
replaces ``"preserver"`` with concrete dimension sizes that try
803-
to get chunks to be close to certain size in bytes, provided by the ``target=``
800+
This takes in a chunks argument potentially containing ``"preserve"`` for several
801+
dimensions. This function replaces ``"preserve"`` with concrete dimension sizes that
802+
try to get chunks to be close to a certain size in bytes, provided by the ``target=``
804803
keyword. Any dimensions marked as ``"preserve"`` will potentially be multiplied
805804
by some factor to get close to the byte target, while never splitting
806-
``previous_chunks``.
805+
``previous_chunks``. If chunks are non-uniform along a particular dimension
806+
then that dimension will always use exactly ``previous_chunks``.
807807
808808
Examples
809809
--------
@@ -825,55 +825,70 @@ def preserve_chunks(
825825
... )
826826
(512, 256, 1)
827827
828+
>>> ChunkManagerEntrypoint.preserve_chunks(
829+
... chunks=("preserve", "preserve", 1),
830+
... shape=(1280, 1280, 20),
831+
... target=1 * 1024 * 1024,
832+
... typesize=8,
833+
... previous_chunks=((128,) * 10, (128, 256, 256, 512), (1,) * 20),
834+
... )
835+
(256, (128, 256, 256, 512), 1)
836+
828837
Parameters
829838
----------
830-
chunks: tuple[int | str | tuple, ...]
839+
chunks: tuple[int | str | tuple[int], ...]
831840
A tuple of either dimensions or tuples of explicit chunk dimensions
832-
Some entries should be "preserve". Any explicit dimensions must match or
833-
be multiple of ``previous_chunks``
841+
Some entries should be "preserve".
834842
shape: tuple[int]
835843
The shape of the array
836844
target: int
837845
The target size of the chunk in bytes.
838846
typesize: int
839847
The size, in bytes, of each element of the chunk.
840-
previous_chunks: tuple[int]
841-
Size of chunks being preserved. Expressed as a tuple of ints which matches
842-
the way chunks are encoded in Zarr.
843-
"""
844-
# pop the first item off in case it's a tuple of tuples
845-
preferred_chunks = np.array(
846-
[c if isinstance(c, int) else c[0] for c in previous_chunks]
847-
)
848-
849-
# "preserve" stays as "preserve"
850-
# None or empty tuple means match preferred_chunks
851-
# -1 means whole dim is in one chunk
852-
desired_chunks = np.array(
853-
[
854-
c or preferred_chunks[i] if c != -1 else shape[i]
855-
for i, c in enumerate(chunks)
856-
]
857-
)
858-
preserve_chunks = desired_chunks == "preserve"
859-
860-
if not preserve_chunks.any():
848+
previous_chunks: tuple[int | tuple[int], ...]
849+
Size of chunks being preserved. Expressed as a tuple of ints or tuple
850+
of tuples of ints.
851+
"""
852+
new_chunks = [*previous_chunks]
853+
auto_dims = [c == "preserve" for c in chunks]
854+
max_chunks = np.array(shape)
855+
for i, previous_chunk in enumerate(previous_chunks):
856+
chunk = chunks[i]
857+
if chunk == -1:
858+
# -1 means whole dim is in one chunk
859+
new_chunks[i] = shape[i]
860+
else:
861+
if isinstance(previous_chunk, tuple):
862+
# For uniform chunks just take the first item
863+
if previous_chunk[1:-1] == previous_chunk[:-2]:
864+
previous_chunk = previous_chunk[0]
865+
# For non-uniform chunks, leave them alone
866+
else:
867+
auto_dims[i] = False
868+
max_chunks[i] = max(previous_chunk)
869+
870+
if isinstance(previous_chunk, int):
871+
# preserve, None or () means we want to track previous chunk
872+
if chunk == "preserve" or not chunk:
873+
max_chunks[i] = previous_chunk
874+
# otherwise use the explicitly provided chunk
875+
else:
876+
new_chunks[i] = chunk
877+
max_chunks[i] = chunk if isinstance(chunk, int) else max(chunk)
878+
879+
if not any(auto_dims):
861880
return chunks
862881

863-
new_chunks = np.where(preserve_chunks, preferred_chunks, desired_chunks).astype(
864-
int
865-
)
866-
867882
while True:
868883
# Repeatedly look for the dim with the most chunks and multiply it by 2.
869884
# Stop when:
870885
# 1a. we are larger than the target chunk size OR
871886
# 1b. we are within 50% of the target chunk size OR
872887
# 2. the chunk covers the entire array
873888

874-
num_chunks = np.array(shape) / new_chunks * preserve_chunks
889+
num_chunks = np.array(shape) / max_chunks * auto_dims
875890
idx = np.argmax(num_chunks)
876-
chunk_bytes = np.prod(new_chunks) * typesize
891+
chunk_bytes = np.prod(max_chunks) * typesize
877892

878893
if chunk_bytes > target or abs(chunk_bytes - target) / target < 0.5:
879894
break
@@ -882,5 +897,6 @@ def preserve_chunks(
882897
break
883898

884899
new_chunks[idx] = min(new_chunks[idx] * 2, shape[idx])
900+
max_chunks[idx] = new_chunks[idx]
885901

886-
return tuple(int(x) for x in new_chunks)
902+
return tuple(new_chunks)

0 commit comments

Comments
 (0)