Skip to content

Commit dd146c0

Browse files
committed
Use the last dim first to take advantage of C-ordered linearization
1 parent fce8cd7 commit dd146c0

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

xarray/namedarray/parallelcompat.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -823,7 +823,7 @@ def preserve_chunks(
823823
... typesize=8,
824824
... previous_chunks=(128, 128, 1),
825825
... )
826-
(512, 256, 1)
826+
(128, 1024, 1)
827827
828828
>>> ChunkManagerEntrypoint.preserve_chunks(
829829
... chunks=("preserve", "preserve", 1),
@@ -861,6 +861,7 @@ def preserve_chunks(
861861
if isinstance(previous_chunk, tuple):
862862
# For uniform chunks just take the first item
863863
if previous_chunk[1:-1] == previous_chunk[:-2]:
864+
new_chunks[i] = previous_chunk[0]
864865
previous_chunk = previous_chunk[0]
865866
# For non-uniform chunks, leave them alone
866867
else:
@@ -880,14 +881,13 @@ def preserve_chunks(
880881
return chunks
881882

882883
while True:
883-
# Repeatedly look for the dim with the most chunks and multiply it by 2.
884+
# Repeatedly look for the last dim with more than one chunk and multiply it by 2.
884885
# Stop when:
885886
# 1a. we are larger than the target chunk size OR
886887
# 1b. we are within 50% of the target chunk size OR
887888
# 2. the chunk covers the entire array
888889

889890
num_chunks = np.array(shape) / max_chunks * auto_dims
890-
idx = np.argmax(num_chunks)
891891
chunk_bytes = np.prod(max_chunks) * typesize
892892

893893
if chunk_bytes > target or abs(chunk_bytes - target) / target < 0.5:
@@ -896,6 +896,8 @@ def preserve_chunks(
896896
if (num_chunks <= 1).all():
897897
break
898898

899+
idx = int(np.nonzero(num_chunks > 1)[0][-1])
900+
899901
new_chunks[idx] = min(new_chunks[idx] * 2, shape[idx])
900902
max_chunks[idx] = new_chunks[idx]
901903

xarray/tests/test_backends.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -7439,7 +7439,7 @@ def test_chunking_consistency(chunks, tmp_path: Path) -> None:
74397439
@pytest.mark.parametrize(
74407440
"chunks,expected",
74417441
[
7442-
("preserve", (320, 320)),
7442+
("preserve", (160, 500)),
74437443
(-1, (500, 500)),
74447444
({}, (10, 10)),
74457445
({"x": "preserve"}, (500, 10)),

0 commit comments

Comments
 (0)