diff --git a/cubed/core/ops.py b/cubed/core/ops.py
index 1d50ff5fc..71cff5be2 100644
--- a/cubed/core/ops.py
+++ b/cubed/core/ops.py
@@ -18,6 +18,7 @@
 from cubed.backend_array_api import namespace as nxp
 from cubed.core.array import CoreArray, check_array_specs, compute, gensym
 from cubed.core.plan import Plan, intermediate_store
+from cubed.core.rechunk import multistage_regular_rechunking_plan
 from cubed.primitive.blockwise import blockwise as primitive_blockwise
 from cubed.primitive.blockwise import general_blockwise as primitive_general_blockwise
 from cubed.primitive.memory import get_buffer_copies
@@ -39,7 +40,6 @@
 from cubed.vendor.dask.array.utils import validate_axis
 from cubed.vendor.dask.blockwise import broadcast_dimensions
 from cubed.vendor.dask.utils import has_keyword
-from cubed.vendor.rechunker.algorithm import multistage_rechunking_plan
 
 if TYPE_CHECKING:
     from cubed.array_api.array_object import Array
@@ -1071,7 +1071,7 @@ def _rechunk_plan(x, chunks, *, min_mem=None):
     rechunker_max_mem = (spec.allowed_mem - spec.reserved_mem) // total_copies
     if min_mem is None:
         min_mem = min(rechunker_max_mem // 20, x.nbytes)
-    stages = multistage_rechunking_plan(
+    stages = multistage_regular_rechunking_plan(
         shape=x.shape,
         source_chunks=source_chunks,
         target_chunks=target_chunks,
diff --git a/cubed/core/rechunk.py b/cubed/core/rechunk.py
new file mode 100644
index 000000000..60be69faf
--- /dev/null
+++ b/cubed/core/rechunk.py
@@ -0,0 +1,233 @@
+"""Core rechunking algorithm based on rechunker, but adapted for Cubed to support regular Zarr chunks."""
+
+import logging
+import warnings
+from math import floor, prod
+from typing import List, Optional, Sequence
+
+import numpy as np
+
+from cubed.vendor.rechunker.algorithm import (
+    MAX_STAGES,
+    ExcessiveIOWarning,
+    _calculate_shared_chunks,
+    _MultistagePlan,
+    calculate_single_stage_io_ops,
+    consolidate_chunks,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def verify_chunk_compatibility(
+    shape,
+    write_chunks,
+    target_chunks,
+):
+    for n, wc, tc in zip(shape, write_chunks, target_chunks):
+        assert (wc == n) or (wc % tc == 0), (
+            f"write chunks {write_chunks} do not evenly slice target chunks {target_chunks}, "
+            f"since {wc} is not a multiple of {tc}"
+        )
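+
+
+# Illustrative example (values taken from the tests in this change, not part of
+# the public API): for shape (1001, 1001), write chunks (30, 376) evenly slice
+# target chunks (15, 376) because 30 % 15 == 0 and 376 % 376 == 0, so the
+# assertion above passes; write chunks of (38, 376) would fail for the same
+# target since 38 is neither the full axis length nor a multiple of 15.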
+
+
+def multspace(start: int, stop: int, num: int, endpoints: bool = False):
+    """
+    Returns numbers that are roughly evenly-spaced along a log scale,
+    and where each number is an exact multiple of the smallest.
+
+    Note that start and stop endpoints are not returned.
+
+    The returned values will always be an exact multiple of the smaller
+    of start and stop. But the larger of start and stop will not necessarily
+    be a multiple of any of the returned values.
+
+    Examples::
+
+        >>> multspace(1, 1000, 2)
+        [10, 100]
+        >>> multspace(1000, 1, 2)
+        [100, 10]
+        >>> multspace(24, 43800, 3)
+        [144, 1008, 6048]
+    """
+
+    if start < 1:
+        raise NotImplementedError(f"start must be 1 or more, but was {start}")
+
+    if stop < 1:
+        raise NotImplementedError(f"stop must be 1 or more, but was {stop}")
+
+    if num < 0:
+        raise NotImplementedError(f"num must be non-negative, but was {num}")
+
+    if endpoints:
+        raise NotImplementedError("endpoints is not supported in multspace")
+
+    if start > stop:
+        return list(reversed(multspace(stop, start, num)))
+
+    return list(_multspace(start, stop, num))[1:-1]
+
+
+def _multspace(start, stop, num):
+    vals = np.geomspace(start, stop, num + 2)
+    vint = 1
+    for v in vals:
+        # next value is int multiple closest to geomspace value (rounded down)
+        vint = max(floor(v / vint) * vint, 1)
+        yield vint
+
+
+def calculate_regular_stage_chunks(
+    read_chunks: tuple[int, ...],
+    write_chunks: tuple[int, ...],
+    stage_count: int = 1,
+) -> list[tuple[int, ...]]:
+    """
+    Calculate chunks after each stage of a multi-stage rechunking.
+
+    Unlike `calculate_stage_chunks` in rechunker, this implementation
+    always returns intermediate chunk sizes that work with regularly
+    chunked Zarr arrays.
+    """
+    stages = []
+    for rc, wc in zip(read_chunks, write_chunks):
+        stages.append(multspace(rc, wc, num=stage_count - 1))
+    return [tuple(chunks) for chunks in np.array(stages).T.tolist()]
+
+
+def _fix_copy_chunks(shape, copy_chunks, target_chunks):
+    # if copy chunks are bigger than target chunks in a particular axis, then
+    # round them down to the largest multiple of the target so they are aligned
+    return tuple(
+        cc if (cc <= tc) or (cc == n) or (cc % tc == 0) else (cc // tc) * tc
+        for n, cc, tc in zip(shape, copy_chunks, target_chunks)
+    )
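+
+
+# Worked example (numbers mirror test_calculate_stage_chunks in this change):
+# going from ERA5-style read chunks (93, 721, 1440) to write chunks
+# (350640, 10, 10) in three stages gives intermediate chunks
+# [(1395, 160, 250), (22320, 40, 50)]; on every axis each value is an exact
+# multiple of the smaller of the corresponding read/write chunk sizes, and
+# successive stage chunk sizes divide one another evenly, so the stages stay
+# aligned.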
+ """ + + ndim = len(shape) + if len(source_chunks) != ndim: + raise ValueError(f"source_chunks {source_chunks} must have length {ndim}") + if len(target_chunks) != ndim: + raise ValueError(f"target_chunks {target_chunks} must have length {ndim}") + + source_chunk_mem = itemsize * prod(source_chunks) + target_chunk_mem = itemsize * prod(target_chunks) + + if source_chunk_mem > max_mem: + raise ValueError( + f"Source chunk memory ({source_chunk_mem}) exceeds max_mem ({max_mem})" + ) + if target_chunk_mem > max_mem: + raise ValueError( + f"Target chunk memory ({target_chunk_mem}) exceeds max_mem ({max_mem})" + ) + + if max_mem < min_mem: # basic sanity check + raise ValueError( + f"max_mem ({max_mem}) cannot be smaller than min_mem ({min_mem})" + ) + + if consolidate_writes: + logger.debug( + f"consolidate_write_chunks({shape}, {target_chunks}, {itemsize}, {max_mem})" + ) + write_chunks = consolidate_chunks(shape, target_chunks, itemsize, max_mem) + else: + write_chunks = tuple(target_chunks) + + if consolidate_reads: + read_chunk_limits: List[Optional[int]] = [] + for sc, wc in zip(source_chunks, write_chunks): + limit: Optional[int] + if wc > sc: + # consolidate reads over this axis, up to the write chunk size + limit = wc + else: + # don't consolidate reads over this axis + limit = None + read_chunk_limits.append(limit) + + logger.debug( + f"consolidate_read_chunks({shape}, {source_chunks}, {itemsize}, {max_mem}, {read_chunk_limits})" + ) + read_chunks = consolidate_chunks( + shape, source_chunks, itemsize, max_mem, read_chunk_limits + ) + else: + read_chunks = tuple(source_chunks) + + prev_io_ops: Optional[float] = None + prev_plan: Optional[_MultistagePlan] = None + + # increase the number of stages until min_mem is exceeded + for stage_count in range(1, MAX_STAGES): + stage_chunks = calculate_regular_stage_chunks( + read_chunks, write_chunks, stage_count + ) + # adjust read_chunks to ensure they align with following stage + read_chunks = _fix_copy_chunks( + shape, read_chunks, (stage_chunks + [write_chunks])[0] + ) + pre_chunks = [read_chunks] + stage_chunks + post_chunks = stage_chunks + [write_chunks] + + int_chunks = [ + _calculate_shared_chunks(pre, post) + for pre, post in zip(pre_chunks, post_chunks) + ] + plan = list(zip(pre_chunks, int_chunks, post_chunks)) + + int_mem = min(itemsize * prod(chunks) for chunks in int_chunks) + if int_mem >= min_mem: + return plan # success! + + io_ops = sum( + calculate_single_stage_io_ops(shape, pre, post) + for pre, post in zip(pre_chunks, post_chunks) + ) + if prev_io_ops is not None and io_ops > prev_io_ops: + warnings.warn( + "Search for multi-stage rechunking plan terminated before " + "achieving the minimum memory requirement due to increasing IO " + f"requirements. Smallest intermediates have size {int_mem}. " + f"Consider decreasing min_mem ({min_mem}) or increasing " + f"({max_mem}) to find a more efficient plan.", + category=ExcessiveIOWarning, + stacklevel=2, + ) + assert prev_plan is not None + return prev_plan + + prev_io_ops = io_ops + prev_plan = plan + + raise AssertionError( + "Failed to find a feasible multi-staging rechunking scheme satisfying " + f"min_mem ({min_mem}) and max_mem ({max_mem}) constraints. 
" + "Please file a bug report on GitHub: " + "https://github.com/pangeo-data/rechunker/issues\n\n" + "Include the following debugging info:\n" + f"shape={shape}, source_chunks={source_chunks}, " + f"target_chunks={target_chunks}, itemsize={itemsize}, " + f"min_mem={min_mem}, max_mem={max_mem}, " + f"consolidate_reads={consolidate_reads}, " + f"consolidate_writes={consolidate_writes}" + ) diff --git a/cubed/tests/test_rechunk.py b/cubed/tests/test_rechunk.py index 1da1f07bc..ab4e178b9 100644 --- a/cubed/tests/test_rechunk.py +++ b/cubed/tests/test_rechunk.py @@ -2,7 +2,20 @@ import cubed import cubed as xp +from cubed._testing import assert_array_equal +from cubed.backend_array_api import namespace as nxp from cubed.core.ops import _store_array +from cubed.core.rechunk import ( + calculate_regular_stage_chunks, + multistage_regular_rechunking_plan, + multspace, + verify_chunk_compatibility, +) +from cubed.utils import itemsize +from cubed.vendor.rechunker.algorithm import ( + calculate_stage_chunks, + multistage_rechunking_plan, +) @pytest.mark.parametrize( @@ -13,10 +26,11 @@ "expected_max_output_blocks", ), [ - (None, 3, 16, 15), # multistage rechunk - more stages, lower fan in/out + (None, 5, 7, 9), # multistage rechunk - more stages, lower fan in/out (0, 1, 3771, 3460), # single stage rechunk - very high fan in/out ], ) +@pytest.mark.filterwarnings("ignore::UserWarning") def test_rechunk_era5( tmp_path, min_mem, @@ -78,12 +92,17 @@ def test_rechunk_era5_chunk_sizes(spec): from cubed.core.ops import _rechunk_plan - assert list(_rechunk_plan(a, target_chunks)) == [ - ((93, 721, 1440), (93, 173, 396)), - ((1447, 173, 396), (1447, 173, 396)), - ((1447, 173, 396), (1447, 41, 109)), - ((22528, 41, 109), (22528, 41, 109)), - ((22528, 41, 109), (22528, 10, 30)), + rechunk_plan = list(_rechunk_plan(a, target_chunks)) + assert rechunk_plan == [ + ((93, 721, 1440), (93, 240, 480)), + ((465, 240, 480), (465, 240, 480)), + ((465, 240, 480), (465, 120, 240)), + ((2325, 120, 240), (2325, 120, 240)), + ((2325, 120, 240), (2325, 40, 120)), + ((11625, 40, 120), (11625, 40, 120)), + ((11625, 40, 120), (11625, 20, 60)), + ((58125, 20, 60), (58125, 20, 60)), + ((58125, 20, 60), (58125, 10, 30)), ((350640, 10, 30), (350640, 10, 10)), ] @@ -107,3 +126,105 @@ def test_rechunk_and_store(): # store should not change number of tasks assert c.plan().num_tasks == num_tasks + + +def test_rechunk_hypothesis_generated_bug(): + rechunk_shapes = (tuple([1001, 1001]), (38, 376), (5, 146)) + shape, source_chunks, target_chunks = rechunk_shapes + + spec = cubed.Spec(allowed_mem=8000000 / 10) + a = xp.ones(shape, chunks=source_chunks, spec=spec) + + from cubed.core.ops import _rechunk_plan + + rechunk_plan = list(_rechunk_plan(a, target_chunks)) + assert rechunk_plan == [((30, 376), (15, 376)), ((15, 1001), (5, 146))] + for copy_chunks, target_chunks in rechunk_plan: + verify_chunk_compatibility(shape, copy_chunks, target_chunks) + + b = a.rechunk((5, 146)) + + assert_array_equal(b.compute(), nxp.ones(shape)) + + +@pytest.mark.parametrize( + ("start", "stop", "num", "expected"), + [ + (1, 1000, 2, [10, 100]), + (1000, 1, 2, [100, 10]), + (1, 1000, 0, []), + (25, 25, 1, [25]), + (24, 43800, 3, [144, 1008, 6048]), + ], +) +def test_multspace(start, stop, num, expected): + result = multspace(start, stop, num) + + assert len(result) == num + assert result == expected + + # check that each value is a multiple of the smallest + smallest = min(start, stop) + assert all(val % smallest == 0 for val in result) + + +def 
+def test_calculate_stage_chunks():
+    # era5
+    # shape = (350640, 721, 1440)
+    # source_chunks = (31, 721, 1440)
+    target_chunks = (350640, 10, 10)
+
+    read_chunks = (93, 721, 1440)
+
+    # cubed algorithm produces regular chunkings:
+    # 1395 and 22320 are multiples of 93 (from read chunks)
+    # 160 and 40 are multiples of 10 (from target chunks)
+    # 250 and 50 are multiples of 10 (from target chunks)
+    stage_chunks = calculate_regular_stage_chunks(
+        read_chunks, target_chunks, stage_count=3
+    )
+    assert stage_chunks == [(1395, 160, 250), (22320, 40, 50)]
+
+    # whereas the rechunker algorithm does not produce regular chunkings
+    stage_chunks = calculate_stage_chunks(read_chunks, target_chunks, stage_count=3)
+    assert stage_chunks == [(1447, 173, 274), (22528, 41, 52)]
+
+
+def test_multistage_rechunking_plan():
+    # era5
+    shape = (350640, 721, 1440)
+    source_chunks = (31, 721, 1440)
+    target_chunks = (350640, 10, 10)
+    min_mem = 25000000
+    max_mem = 500000000
+
+    stages = multistage_rechunking_plan(
+        shape, source_chunks, target_chunks, itemsize(xp.float32), min_mem, max_mem
+    )
+    assert stages == [
+        ((93, 721, 1440), (93, 173, 396), (1447, 173, 396)),
+        ((1447, 173, 396), (1447, 41, 109), (22528, 41, 109)),
+        ((22528, 41, 109), (22528, 10, 30), (350640, 10, 30)),
+    ]
+
+    with pytest.raises(AssertionError, match="do not evenly slice target chunks"):  # noqa: PT012
+        for stage in stages:
+            read_chunks, int_chunks, write_chunks = stage
+            verify_chunk_compatibility(shape, read_chunks, int_chunks)
+
+    stages = multistage_regular_rechunking_plan(
+        shape, source_chunks, target_chunks, itemsize(xp.float32), min_mem, max_mem
+    )
+    assert stages == [
+        ((93, 721, 1440), (93, 240, 480), (465, 240, 480)),
+        ((465, 240, 480), (465, 120, 240), (2325, 120, 240)),
+        ((2325, 120, 240), (2325, 40, 120), (11625, 40, 120)),
+        ((11625, 40, 120), (11625, 20, 60), (58125, 20, 60)),
+        ((58125, 20, 60), (58125, 10, 30), (350640, 10, 30)),
+    ]
+    for i, stage in enumerate(stages):
+        last_stage = i == len(stages) - 1
+        read_chunks, int_chunks, write_chunks = stage
+        verify_chunk_compatibility(shape, read_chunks, int_chunks)
+        if last_stage:
+            verify_chunk_compatibility(shape, write_chunks, target_chunks)
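+
+
+# Note (illustrative): in the regular plan above, each stage's read chunks
+# either cover the whole axis or are an exact multiple of that stage's
+# intermediate chunks (for example 240 = 2 * 120 in the second stage, and 721
+# spans the full second axis in the first stage), which is the property
+# verify_chunk_compatibility asserts; only the final stage's write chunks also
+# have to slice the target chunks evenly.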