diff --git a/cubed/core/ops.py b/cubed/core/ops.py index fbf66233..811435f5 100644 --- a/cubed/core/ops.py +++ b/cubed/core/ops.py @@ -1100,7 +1100,8 @@ def _rechunk_plan(x, chunks, *, min_mem=None): yield read_chunks, target_chunks_ else: yield read_chunks, int_chunks - yield write_chunks, target_chunks_ + if last_stage: + yield write_chunks, target_chunks_ def _rechunk(x, copy_chunks, target_chunks): diff --git a/cubed/tests/test_rechunk.py b/cubed/tests/test_rechunk.py index ab4e178b..74ce6a13 100644 --- a/cubed/tests/test_rechunk.py +++ b/cubed/tests/test_rechunk.py @@ -60,8 +60,8 @@ def test_rechunk_era5( if d.get("op_name", None) == "rechunk" ] - # each stage has two ops due to intermediate store - assert len(rechunks) == expected_num_stages * 2 + # number of rechunk copy ops is one more than the number of stages + assert len(rechunks) == expected_num_stages + 1 max_input_blocks = max( d["pipeline"].config.num_input_blocks[0] for _, d in rechunks @@ -95,13 +95,9 @@ def test_rechunk_era5_chunk_sizes(spec): rechunk_plan = list(_rechunk_plan(a, target_chunks)) assert rechunk_plan == [ ((93, 721, 1440), (93, 240, 480)), - ((465, 240, 480), (465, 240, 480)), ((465, 240, 480), (465, 120, 240)), - ((2325, 120, 240), (2325, 120, 240)), ((2325, 120, 240), (2325, 40, 120)), - ((11625, 40, 120), (11625, 40, 120)), ((11625, 40, 120), (11625, 20, 60)), - ((58125, 20, 60), (58125, 20, 60)), ((58125, 20, 60), (58125, 10, 30)), ((350640, 10, 30), (350640, 10, 10)), ]