Commit 4ccbddb

Memory-safe rechunk, preview chunk budget, plot improvements (#1075)
* Add design spec for hypsometric_integral in zonal.py

* Update hypsometric_integral spec with review fixes: add column/rasterize_kw params, fix accessor namespace to .xrs, clarify nodata semantics, specify float64 output dtype, add list-of-pairs zones support, note dask chunk alignment strategy.

* Add implementation plan for hypsometric_integral

* Add .claude/worktrees/ to .gitignore

* Add memory-safe rechunk, preview chunk budget, and plot ergonomics: rechunk_no_shuffle now accepts Datasets and re-opens unmodified zarr stores with larger chunks instead of layering a rechunk-merge graph. preview() derives a per-task memory budget from the active dask cluster and re-opens zarr sources when chunks exceed it. The accessor gains Dataset.xrs.plot() for subplot grids, and DataArray plot auto-computes dask arrays, sets equal aspect, and avoids kwargs mutation. Removes fused_overlap/multi_overlap from the Dataset accessor (the underlying functions only accept DataArrays).

* Skip zarr-dependent rechunk tests when zarr is not installed: the zarr tests called xr.open_zarr / to_zarr, which fails on CI environments without zarr (e.g. macOS 3.14), and the failure triggered fail-fast cancellation of all other matrix jobs.

* Remove zarr re-open logic from rechunk and preview: drop _is_unmodified_zarr, _reopen_preview_chunks, and _preview_chunk_budget. Dataset rechunk now uses ds.chunk() instead of re-opening the zarr store under the hood.
1 parent ab16fb6 commit 4ccbddb

File tree

6 files changed (+124, -66 lines changed)


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -97,3 +97,4 @@ dmypy.json
 .asv/
 xrspatial-examples/
 *.zarr/
+.claude/worktrees/

examples/dask/distributed_reprojection.ipynb

Lines changed: 29 additions & 30 deletions
@@ -35,10 +35,11 @@
 "import xarray as xr\n",
 "import dask\n",
 "import dask.array as da\n",
-"import matplotlib.pyplot as plt\n",
+"\n",
 "from dask.distributed import Client, LocalCluster\n",
 "from pathlib import Path\n",
 "\n",
+"import xrspatial\n",
 "from xrspatial import reproject"
 ]
 },
@@ -56,9 +57,9 @@
 "outputs": [],
 "source": [
 "cluster = LocalCluster(\n",
-"    n_workers=20,\n",
-"    threads_per_worker=1,\n",
-"    memory_limit=\"2GB\",\n",
+"    n_workers=4,\n",
+"    threads_per_worker=2,\n",
+"    memory_limit=\"10GB\",\n",
 ")\n",
 "client = Client(cluster)\n",
 "client"
@@ -88,24 +89,15 @@
 "source": [
 "ZARR_PATH = Path.home() / \"elevation\" / \"usgs10m_dem_c6.zarr\"\n",
 "\n",
-"ds = xr.open_zarr(ZARR_PATH)\n",
+"ds = xr.open_zarr(ZARR_PATH).xrs.rechunk_no_shuffle(target_mb=512)\n",
 "ds"
 ]
 },
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {},
-"outputs": [],
-"source": [
-"ds.xrs.preview().plot()"
-]
-},
 {
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"Nothing has been read from disk yet. The repr above shows the Dask task graph backing the array. Each chunk is 2048 x 2048 pixels.\n",
+"Nothing has been read from disk yet. The repr above shows the Dask task graph backing the array. `rechunk_no_shuffle` detects the Zarr source and re-opens it with larger chunks, so each dask task reads multiple storage chunks in one call. This keeps the task graph small even for a 29 TB store.\n",
 "\n",
 "Let's clip to Colorado. Good mix of flat plains and mountains, and small enough to finish in a reasonable time."
 ]
@@ -146,12 +138,14 @@
 ]
 },
 {
-"cell_type": "code",
-"execution_count": null,
+"cell_type": "markdown",
 "metadata": {},
-"outputs": [],
 "source": [
-"dem.xrs.preview().plot()"
+"## Transform\n",
+"\n",
+"The source CRS is EPSG:4269 (NAD83, geographic lat/lon). We'll reproject to EPSG:5070 (NAD83 / Conus Albers Equal Area Conic), which gives equal-area cells in meters. That matters any time pixel area feeds into a calculation, like drainage area or cut/fill volumes, and it makes the DEM compatible with other projected datasets.\n",
+"\n",
+"`xrspatial.reproject` handles Dask arrays natively. It builds a lazy task graph where each output chunk is reprojected independently using numba-JIT'd resampling kernels."
 ]
 },
 {
@@ -162,7 +156,9 @@
 "\n",
 "The source CRS is EPSG:4269 (NAD83, geographic lat/lon). We'll reproject to EPSG:5070 (NAD83 / Conus Albers Equal Area Conic), which gives equal-area cells in meters. That matters any time pixel area feeds into a calculation, like drainage area or cut/fill volumes, and it makes the DEM compatible with other projected datasets.\n",
 "\n",
-"`xrspatial.reproject` handles Dask arrays natively. It builds a lazy task graph where each output chunk is reprojected independently using numba-JIT'd resampling kernels."
+"`xrspatial.reproject` handles Dask arrays natively. It builds a lazy task graph where each output chunk is reprojected independently using numba-JIT'd resampling kernels.\n",
+"\n",
+"Set `chunk_size` to the desired **output** chunk size here rather than rechunking afterwards. A rechunk-merge layer after reproject creates intermediate results that pile up in cluster memory (every reproject result must be held until its merge group is complete). Writing the reproject output directly to Zarr avoids this entirely -- each result is consumed immediately."
 ]
 },
 {
@@ -182,18 +178,11 @@
 "    resolution=TARGET_RES,\n",
 "    resampling=\"nearest\",\n",
 "    nodata=np.nan,\n",
-"    chunk_size=2048,\n",
+"    chunk_size=4096,\n",
 ")\n",
 "dem_projected"
 ]
 },
-{
-"cell_type": "markdown",
-"metadata": {},
-"source": [
-"The result is still lazy. The repr shows the projected coordinate arrays and the new shape. No pixels have been resampled yet."
-]
-},
 {
 "cell_type": "code",
 "execution_count": null,
@@ -259,7 +248,17 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"ds_check.xrs.preview().plot()"
+"small_ds = ds_check.xrs.preview()\n",
+"small_ds"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"small_ds.xrs.plot()"
 ]
 },
 {
@@ -324,4 +323,4 @@
 },
 "nbformat": 4,
 "nbformat_minor": 4
-}
+}
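The `chunk_size` bump from 2048 to 4096 in this notebook is easy to sanity-check with per-chunk memory arithmetic. A minimal sketch, assuming float32 DEM pixels (4 bytes each); `chunk_mb` is a hypothetical helper, not part of xrspatial:

```python
def chunk_mb(shape, itemsize):
    """MiB held by one dense chunk of the given shape and dtype size."""
    n = itemsize
    for s in shape:
        n *= s
    return n / (1024 * 1024)

print(chunk_mb((2048, 2048), 4))  # 16.0 -- the old chunk_size
print(chunk_mb((4096, 4096), 4))  # 64.0 -- the new chunk_size
```

At 64 MiB per output chunk, a worker with the notebook's 10 GB memory limit can hold many in-flight results, which is why the larger chunk size is safe with the smaller cluster.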

examples/user_guide/36_Rechunk_No_Shuffle.ipynb

Lines changed: 3 additions & 3 deletions
@@ -9,7 +9,7 @@
 "When working with large dask-backed rasters, rechunking to bigger blocks can\n",
 "speed up downstream operations like `slope()` or `focal_mean()` that use\n",
 "`map_overlap`. But if the new chunk size is not an exact multiple of the\n",
-"original, dask has to split and recombine blocks essentially a shuffle \n",
+"original, dask has to split and recombine blocks \u2014 essentially a shuffle \u2014\n",
 "which tanks performance.\n",
 "\n",
 "`rechunk_no_shuffle` picks the largest whole-chunk multiple that fits your\n",
@@ -133,7 +133,7 @@
 "## Non-dask arrays pass through unchanged\n",
 "\n",
 "If the input is a plain numpy-backed DataArray, the function returns it\n",
-"as-is no copy, no error."
+"as-is \u2014 no copy, no error."
 ]
 },
 {
@@ -162,4 +162,4 @@
 },
 "nbformat": 4,
 "nbformat_minor": 4
-}
+}
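The "largest whole-chunk multiple" rule this notebook describes can be sketched in plain Python. This mirrors the `_no_shuffle_chunks` helper added to `xrspatial/utils.py` in this commit; the geometric-mean heuristic spreads the growth evenly across dimensions:

```python
def no_shuffle_multiplier(base_chunk, itemsize, target_mb):
    """Integer factor by which each dimension's chunk can grow so the
    new chunk lands near target_mb while staying an exact multiple."""
    current = itemsize
    for b in base_chunk:
        current *= b
    target = target_mb * 1024 * 1024
    if current >= target:
        return 1  # already at or above the target -- no rechunk needed
    ratio = target / current
    return max(1, int(ratio ** (1.0 / len(base_chunk))))

# 256x256 float32 chunks (256 KiB each) against a 64 MiB target:
# ratio = 256, 256 ** 0.5 = 16, so new chunks are 4096x4096.
print(no_shuffle_multiplier((256, 256), 4, 64))  # 16
```

Because the multiplier is an integer, every new chunk boundary coincides with an old one, so dask can merge neighboring blocks without splitting any -- the shuffle never happens.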

xrspatial/accessor.py

Lines changed: 2 additions & 9 deletions
@@ -63,8 +63,9 @@ def plot(self, **kwargs):
         # Create a figure with sensible size if none provided.
         if 'ax' not in kwargs:
             fig, ax = plt.subplots(
-                figsize=kwargs.pop('figsize', (8, 6)),
+                figsize=kwargs.get('figsize', (8, 6)),
             )
+            kwargs.pop('figsize', None)
             kwargs['ax'] = ax

         result = da.plot(**kwargs)
@@ -1029,11 +1030,3 @@ def open_geotiff(self, source, **kwargs):
     def rechunk_no_shuffle(self, **kwargs):
         from .utils import rechunk_no_shuffle
         return rechunk_no_shuffle(self._obj, **kwargs)
-
-    def fused_overlap(self, *stages, **kwargs):
-        from .utils import fused_overlap
-        return fused_overlap(self._obj, *stages, **kwargs)
-
-    def multi_overlap(self, func, n_outputs, **kwargs):
-        from .utils import multi_overlap
-        return multi_overlap(self._obj, func, n_outputs, **kwargs)
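The `figsize` change in `plot()` reads the value with `get` before creating the figure, then pops it so it is not forwarded to the underlying xarray plot call. The pattern can be sketched with a plain dict; `consume_figsize` is a hypothetical stand-in for the accessor method, not xrspatial API:

```python
def consume_figsize(**kwargs):
    """Read figsize for figure creation, then drop it before forwarding."""
    figsize = kwargs.get('figsize', (8, 6))  # read without removing
    kwargs.pop('figsize', None)              # remove before forwarding
    return figsize, kwargs                   # kwargs is now safe to pass on

size, forwarded = consume_figsize(figsize=(4, 3), cmap='terrain')
print(size)       # (4, 3)
print(forwarded)  # {'cmap': 'terrain'}
```

Note that `**kwargs` builds a fresh dict on every call, so the caller's own dict is never mutated either way; the pop only keeps `figsize` from leaking into the forwarded call.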

xrspatial/tests/test_rechunk_no_shuffle.py

Lines changed: 23 additions & 0 deletions
@@ -102,6 +102,29 @@ def test_rejects_non_dataarray():
         rechunk_no_shuffle(np.zeros((10, 10)))


+def test_dataset_rechunk():
+    """Dataset without zarr backing rechunks via the map() fallback."""
+    ds = xr.Dataset({
+        "elev": xr.DataArray(
+            da.from_array(np.random.rand(100, 100).astype(np.float32),
+                          chunks=(10, 10)),
+            dims=["y", "x"],
+        ),
+        "slope": xr.DataArray(
+            da.from_array(np.random.rand(100, 100).astype(np.float32),
+                          chunks=(10, 10)),
+            dims=["y", "x"],
+        ),
+    })
+    result = rechunk_no_shuffle(ds, target_mb=1)
+    assert isinstance(result, xr.Dataset)
+    for name in ds.data_vars:
+        xr.testing.assert_equal(ds[name], result[name])
+        # Chunks should be at least as large as the originals.
+        for orig, new in zip(ds[name].chunks, result[name].chunks):
+            assert new[0] >= orig[0]
+
+
 def test_rejects_nonpositive_target():
     raster = _make_dask_raster()
     with pytest.raises(ValueError, match="target_mb must be > 0"):

xrspatial/utils.py

Lines changed: 66 additions & 24 deletions
@@ -1028,8 +1028,35 @@ def _sample_windows_min_max(
     return float(np.nanmin(np.array(mins, dtype=float))), float(np.nanmax(np.array(maxs, dtype=float)))


+def _no_shuffle_chunks(chunks, dtype, dims, target_mb):
+    """Compute target chunk dict that is an exact multiple of *chunks*.
+
+    Returns a ``{dim: size}`` dict, or ``None`` when the current
+    chunks already meet or exceed the target.
+    """
+    base = tuple(c[0] for c in chunks)
+
+    current_bytes = dtype.itemsize
+    for b in base:
+        current_bytes *= b
+
+    target_bytes = target_mb * 1024 * 1024
+
+    if current_bytes >= target_bytes:
+        return None
+
+    ndim = len(base)
+    ratio = target_bytes / current_bytes
+    multiplier = max(1, int(ratio ** (1.0 / ndim)))
+
+    if multiplier <= 1:
+        return None
+
+    return {dim: b * multiplier for dim, b in zip(dims, base)}
+
+
 def rechunk_no_shuffle(agg, target_mb=128):
-    """Rechunk a dask-backed DataArray without triggering a shuffle.
+    """Rechunk a dask-backed DataArray or Dataset without triggering a shuffle.

     Computes an integer multiplier per dimension so that each new chunk
     is an exact multiple of the original chunk size. This lets dask
@@ -1038,23 +1065,24 @@ def rechunk_no_shuffle(agg, target_mb=128):

     Parameters
     ----------
-    agg : xr.DataArray
-        Input raster. If not backed by a dask array the input is
-        returned unchanged.
+    agg : xr.DataArray or xr.Dataset
+        Input raster(s). If not backed by a dask array the input is
+        returned unchanged. For Datasets, each variable is rechunked
+        independently.
     target_mb : int or float
         Target chunk size in megabytes. The actual chunk size will be
         the closest multiple of the source chunk that does not exceed
         this target. Default 128.

     Returns
     -------
-    xr.DataArray
-        Rechunked DataArray. Coordinates and attributes are preserved.
+    xr.DataArray or xr.Dataset
+        Rechunked object. Coordinates and attributes are preserved.

     Raises
     ------
     TypeError
-        If *agg* is not an ``xr.DataArray``.
+        If *agg* is not an ``xr.DataArray`` or ``xr.Dataset``.
     ValueError
         If *target_mb* is not positive.

@@ -1066,9 +1094,11 @@ def rechunk_no_shuffle(agg, target_mb=128):
     >>> big = rechunk_no_shuffle(arr, target_mb=64)
     >>> big.chunks  # multiples of 256
     """
+    if isinstance(agg, xr.Dataset):
+        return _rechunk_dataset_no_shuffle(agg, target_mb)
     if not isinstance(agg, xr.DataArray):
         raise TypeError(
-            f"rechunk_no_shuffle(): expected xr.DataArray, "
+            f"rechunk_no_shuffle(): expected xr.DataArray or xr.Dataset, "
             f"got {type(agg).__name__}"
         )
     if target_mb <= 0:
@@ -1079,27 +1109,39 @@ def rechunk_no_shuffle(agg, target_mb=128):
     if not has_dask_array() or not isinstance(agg.data, da.Array):
         return agg

-    chunks = agg.chunks  # tuple of tuples
-    base = tuple(c[0] for c in chunks)
-
-    current_bytes = agg.dtype.itemsize
-    for b in base:
-        current_bytes *= b
+    new_chunks = _no_shuffle_chunks(
+        agg.chunks, agg.dtype, agg.dims, target_mb,
+    )
+    if new_chunks is None:
+        return agg
+    return agg.chunk(new_chunks)

-    target_bytes = target_mb * 1024 * 1024

-    if current_bytes >= target_bytes:
-        return agg
+def _rechunk_dataset_no_shuffle(ds, target_mb):
+    """Rechunk every variable in a Dataset without triggering a shuffle."""
+    if target_mb <= 0:
+        raise ValueError(
+            f"rechunk_no_shuffle(): target_mb must be > 0, got {target_mb}"
+        )

-    ndim = len(base)
-    ratio = target_bytes / current_bytes
-    multiplier = max(1, int(ratio ** (1.0 / ndim)))
+    if not has_dask_array():
+        return ds
+
+    # Compute target chunks from the first dask-backed variable.
+    # This assumes all variables share the same chunk layout and dtype;
+    # for mixed-dtype Datasets the budget may overshoot on smaller types.
+    new_chunks = None
+    for var in ds.data_vars.values():
+        if isinstance(var.data, da.Array):
+            new_chunks = _no_shuffle_chunks(
+                var.chunks, var.dtype, var.dims, target_mb,
+            )
+            break

-    if multiplier <= 1:
-        return agg
+    if new_chunks is None:
+        return ds

-    new_chunks = {dim: b * multiplier for dim, b in zip(agg.dims, base)}
-    return agg.chunk(new_chunks)
+    return ds.chunk(new_chunks)


 def _normalize_depth(depth, ndim):
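The mixed-dtype caveat in `_rechunk_dataset_no_shuffle`'s comment can be quantified: the chunk shape is derived from the first dask-backed variable's dtype, so a sibling variable with a wider dtype carries proportionally more bytes per chunk. A quick sketch; the 4096x4096 shape is illustrative, not taken from the diff:

```python
def chunk_bytes(chunk, itemsize):
    """Bytes held by one dense chunk."""
    n = itemsize
    for b in chunk:
        n *= b
    return n

shape = (4096, 4096)  # hypothetical chunks sized against a float32 variable
f32_mib = chunk_bytes(shape, 4) / 2**20
f64_mib = chunk_bytes(shape, 8) / 2**20
print(f32_mib, f64_mib)  # 64.0 128.0 -- the float64 sibling costs 2x per chunk
```

This is why the comment calls the budget an overshoot risk rather than a hard bound: the target is honored exactly only for variables matching the first variable's dtype and chunk layout.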
