Skip to content

Commit f198dde

Browse files
committed
Multistage rechunking for regular chunk grids
1 parent 42181d3 commit f198dde

3 files changed

Lines changed: 356 additions & 9 deletions

File tree

cubed/core/ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from cubed.backend_array_api import namespace as nxp
1919
from cubed.core.array import CoreArray, check_array_specs, compute, gensym
2020
from cubed.core.plan import Plan, intermediate_store
21+
from cubed.core.rechunk import multistage_regular_rechunking_plan
2122
from cubed.primitive.blockwise import blockwise as primitive_blockwise
2223
from cubed.primitive.blockwise import general_blockwise as primitive_general_blockwise
2324
from cubed.primitive.memory import get_buffer_copies
@@ -39,7 +40,6 @@
3940
from cubed.vendor.dask.array.utils import validate_axis
4041
from cubed.vendor.dask.blockwise import broadcast_dimensions
4142
from cubed.vendor.dask.utils import has_keyword
42-
from cubed.vendor.rechunker.algorithm import multistage_rechunking_plan
4343

4444
if TYPE_CHECKING:
4545
from cubed.array_api.array_object import Array
@@ -1068,7 +1068,7 @@ def _rechunk_plan(x, chunks, *, min_mem=None):
10681068
rechunker_max_mem = (spec.allowed_mem - spec.reserved_mem) // total_copies
10691069
if min_mem is None:
10701070
min_mem = min(rechunker_max_mem // 20, x.nbytes)
1071-
stages = multistage_rechunking_plan(
1071+
stages = multistage_regular_rechunking_plan(
10721072
shape=x.shape,
10731073
source_chunks=source_chunks,
10741074
target_chunks=target_chunks,

cubed/core/rechunk.py

Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
"""Core rechunking algorithm based on rechunker, but adapted for Cubed to support regular Zarr chunks."""
2+
3+
import logging
4+
import warnings
5+
from math import floor, prod
6+
from typing import List, Optional, Sequence
7+
8+
import numpy as np
9+
10+
from cubed.vendor.rechunker.algorithm import (
11+
MAX_STAGES,
12+
ExcessiveIOWarning,
13+
_calculate_shared_chunks,
14+
_MultistagePlan,
15+
calculate_single_stage_io_ops,
16+
consolidate_chunks,
17+
)
18+
19+
logger = logging.getLogger(__name__)
20+
21+
22+
def verify_chunk_compatibility(
23+
shape,
24+
write_chunks,
25+
target_chunks,
26+
):
27+
for n, wc, tc in zip(shape, write_chunks, target_chunks):
28+
assert (wc == n) or (wc % tc == 0), (
29+
f"write chunks {write_chunks} do not evenly slice target chunks {target_chunks}, "
30+
f"since {wc} is not a multiple of {tc}"
31+
)
32+
33+
34+
def multspace(start: int, stop: int, num: int, endpoints: bool = False):
35+
"""
36+
Returns numbers that are roughly evenly-spaced along a log scale,
37+
and where each number is an exact multiple of the smallest.
38+
39+
Note that start and stop endpoints are not returned.
40+
41+
The returned values will always be an exact multiple of the smaller
42+
of start and stop. But the larger of start and stop will not necessarily
43+
be a multiple of any of the returned values.
44+
45+
Examples::
46+
47+
>>> multspace(1, 1000, 2)
48+
[10, 100]
49+
>>> multspace(1000, 1, 2)
50+
[100, 10]
51+
>>> multspace(24, 43800, 3)
52+
[168, 1008, 7056]
53+
"""
54+
55+
if start < 1:
56+
raise NotImplementedError(f"start must be 1 or more, but was {start}")
57+
58+
if stop < 1:
59+
raise NotImplementedError(f"stop must be 1 or more, but was {stop}")
60+
61+
if num < 0:
62+
raise NotImplementedError(f"num must be positive, but was {num}")
63+
64+
if endpoints:
65+
raise NotImplementedError("endpoints is not supported in multspace")
66+
67+
if start > stop:
68+
return list(reversed(multspace(stop, start, num)))
69+
70+
return list(_multspace(start, stop, num))[1:-1]
71+
72+
73+
def _multspace(start, stop, num):
74+
vals = np.geomspace(start, stop, num + 2)
75+
vint = 1
76+
for v in vals:
77+
vint = floor(v / vint) * vint
78+
yield vint
79+
80+
81+
def calculate_regular_stage_chunks(
82+
read_chunks: tuple[int, ...],
83+
write_chunks: tuple[int, ...],
84+
stage_count: int = 1,
85+
) -> list[tuple[int, ...]]:
86+
"""
87+
Calculate chunks after each stage of a multi-stage rechunking.
88+
89+
Unlike `calculate_stage_chunks` in rechunker, this implementation
90+
always returns intermediate chunks sizes that work with regularly
91+
chunked Zarr arrays.
92+
"""
93+
stages = []
94+
for rc, wc in zip(read_chunks, write_chunks):
95+
stages.append(multspace(rc, wc, num=stage_count - 1))
96+
return [tuple(chunks) for chunks in np.array(stages).T.tolist()]
97+
98+
99+
def _fix_copy_chunks(shape, copy_chunks, target_chunks):
100+
# if copy chunks are bigger than target chunks in a particular axis, then
101+
# round them down to the largest multiple of the target so they are aligned
102+
return tuple(
103+
cc if (cc <= tc) or (cc == n) or (cc % tc == 0) else (cc // tc) * tc
104+
for n, cc, tc in zip(shape, copy_chunks, target_chunks)
105+
)
106+
107+
108+
def multistage_regular_rechunking_plan(
109+
shape: Sequence[int],
110+
source_chunks: Sequence[int],
111+
target_chunks: Sequence[int],
112+
itemsize: int,
113+
min_mem: int,
114+
max_mem: int,
115+
consolidate_reads: bool = True,
116+
consolidate_writes: bool = True,
117+
) -> _MultistagePlan:
118+
"""Calculate a rechunking plan that can use multiple split/consolidate steps.
119+
120+
For best results, max_mem should be significantly larger than min_mem (e.g.,
121+
10x). Otherwise an excessive number of rechunking steps will be required.
122+
"""
123+
124+
ndim = len(shape)
125+
if len(source_chunks) != ndim:
126+
raise ValueError(f"source_chunks {source_chunks} must have length {ndim}")
127+
if len(target_chunks) != ndim:
128+
raise ValueError(f"target_chunks {target_chunks} must have length {ndim}")
129+
130+
source_chunk_mem = itemsize * prod(source_chunks)
131+
target_chunk_mem = itemsize * prod(target_chunks)
132+
133+
if source_chunk_mem > max_mem:
134+
raise ValueError(
135+
f"Source chunk memory ({source_chunk_mem}) exceeds max_mem ({max_mem})"
136+
)
137+
if target_chunk_mem > max_mem:
138+
raise ValueError(
139+
f"Target chunk memory ({target_chunk_mem}) exceeds max_mem ({max_mem})"
140+
)
141+
142+
if max_mem < min_mem: # basic sanity check
143+
raise ValueError(
144+
f"max_mem ({max_mem}) cannot be smaller than min_mem ({min_mem})"
145+
)
146+
147+
if consolidate_writes:
148+
logger.debug(
149+
f"consolidate_write_chunks({shape}, {target_chunks}, {itemsize}, {max_mem})"
150+
)
151+
write_chunks = consolidate_chunks(shape, target_chunks, itemsize, max_mem)
152+
else:
153+
write_chunks = tuple(target_chunks)
154+
155+
if consolidate_reads:
156+
read_chunk_limits: List[Optional[int]] = []
157+
for sc, wc in zip(source_chunks, write_chunks):
158+
limit: Optional[int]
159+
if wc > sc:
160+
# consolidate reads over this axis, up to the write chunk size
161+
limit = wc
162+
else:
163+
# don't consolidate reads over this axis
164+
limit = None
165+
read_chunk_limits.append(limit)
166+
167+
logger.debug(
168+
f"consolidate_read_chunks({shape}, {source_chunks}, {itemsize}, {max_mem}, {read_chunk_limits})"
169+
)
170+
read_chunks = consolidate_chunks(
171+
shape, source_chunks, itemsize, max_mem, read_chunk_limits
172+
)
173+
else:
174+
read_chunks = tuple(source_chunks)
175+
176+
prev_io_ops: Optional[float] = None
177+
prev_plan: Optional[_MultistagePlan] = None
178+
179+
# increase the number of stages until min_mem is exceeded
180+
for stage_count in range(1, MAX_STAGES):
181+
stage_chunks = calculate_regular_stage_chunks(
182+
read_chunks, write_chunks, stage_count
183+
)
184+
# adjust read_chunks to ensure they align with following stage
185+
read_chunks = _fix_copy_chunks(
186+
shape, read_chunks, (stage_chunks + [write_chunks])[0]
187+
)
188+
pre_chunks = [read_chunks] + stage_chunks
189+
post_chunks = stage_chunks + [write_chunks]
190+
191+
int_chunks = [
192+
_calculate_shared_chunks(pre, post)
193+
for pre, post in zip(pre_chunks, post_chunks)
194+
]
195+
plan = list(zip(pre_chunks, int_chunks, post_chunks))
196+
197+
int_mem = min(itemsize * prod(chunks) for chunks in int_chunks)
198+
if int_mem >= min_mem:
199+
return plan # success!
200+
201+
io_ops = sum(
202+
calculate_single_stage_io_ops(shape, pre, post)
203+
for pre, post in zip(pre_chunks, post_chunks)
204+
)
205+
if prev_io_ops is not None and io_ops > prev_io_ops:
206+
warnings.warn(
207+
"Search for multi-stage rechunking plan terminated before "
208+
"achieving the minimum memory requirement due to increasing IO "
209+
f"requirements. Smallest intermediates have size {int_mem}. "
210+
f"Consider decreasing min_mem ({min_mem}) or increasing "
211+
f"({max_mem}) to find a more efficient plan.",
212+
category=ExcessiveIOWarning,
213+
stacklevel=2,
214+
)
215+
assert prev_plan is not None
216+
return prev_plan
217+
218+
prev_io_ops = io_ops
219+
prev_plan = plan
220+
221+
raise AssertionError(
222+
"Failed to find a feasible multi-staging rechunking scheme satisfying "
223+
f"min_mem ({min_mem}) and max_mem ({max_mem}) constraints. "
224+
"Please file a bug report on GitHub: "
225+
"https://github.com/pangeo-data/rechunker/issues\n\n"
226+
"Include the following debugging info:\n"
227+
f"shape={shape}, source_chunks={source_chunks}, "
228+
f"target_chunks={target_chunks}, itemsize={itemsize}, "
229+
f"min_mem={min_mem}, max_mem={max_mem}, "
230+
f"consolidate_reads={consolidate_reads}, "
231+
f"consolidate_writes={consolidate_writes}"
232+
)

0 commit comments

Comments
 (0)