Skip to content

Commit b40f987

Browse files
authored
Remove unneeded copies in rechunk (#878)
1 parent a899c65 commit b40f987

2 files changed

Lines changed: 4 additions & 7 deletions

File tree

cubed/core/ops.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1100,7 +1100,8 @@ def _rechunk_plan(x, chunks, *, min_mem=None):
11001100
yield read_chunks, target_chunks_
11011101
else:
11021102
yield read_chunks, int_chunks
1103-
yield write_chunks, target_chunks_
1103+
if last_stage:
1104+
yield write_chunks, target_chunks_
11041105

11051106

11061107
def _rechunk(x, copy_chunks, target_chunks):

cubed/tests/test_rechunk.py

Lines changed: 2 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -60,8 +60,8 @@ def test_rechunk_era5(
6060
if d.get("op_name", None) == "rechunk"
6161
]
6262

63-
# each stage has two ops due to intermediate store
64-
assert len(rechunks) == expected_num_stages * 2
63+
# number of rechunk copy ops is one more than the number of stages
64+
assert len(rechunks) == expected_num_stages + 1
6565

6666
max_input_blocks = max(
6767
d["pipeline"].config.num_input_blocks[0] for _, d in rechunks
@@ -95,13 +95,9 @@ def test_rechunk_era5_chunk_sizes(spec):
9595
rechunk_plan = list(_rechunk_plan(a, target_chunks))
9696
assert rechunk_plan == [
9797
((93, 721, 1440), (93, 240, 480)),
98-
((465, 240, 480), (465, 240, 480)),
9998
((465, 240, 480), (465, 120, 240)),
100-
((2325, 120, 240), (2325, 120, 240)),
10199
((2325, 120, 240), (2325, 40, 120)),
102-
((11625, 40, 120), (11625, 40, 120)),
103100
((11625, 40, 120), (11625, 20, 60)),
104-
((58125, 20, 60), (58125, 20, 60)),
105101
((58125, 20, 60), (58125, 10, 30)),
106102
((350640, 10, 30), (350640, 10, 10)),
107103
]

0 commit comments

Comments
 (0)