Skip to content

Commit 6258962

Browse files
committed
Add xr.Dataset support to rechunk_no_shuffle (#1069)
Accepts both DataArray and Dataset. For Datasets, each dask-backed variable is rechunked independently. Also adds the method to the Dataset .xrs accessor.
1 parent 0531d16 commit 6258962

File tree

2 files changed

+48
-30
lines changed

2 files changed

+48
-30
lines changed

xrspatial/accessor.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -910,3 +910,9 @@ def open_geotiff(self, source, **kwargs):
910910
y_min, y_max, x_min, x_max)
911911
kwargs.pop('window', None)
912912
return open_geotiff(source, window=window, **kwargs)
913+
914+
# ---- Chunking ----

def rechunk_no_shuffle(self, **kwargs):
    """Rechunk the wrapped object without triggering a dask shuffle.

    Thin accessor wrapper: forwards the wrapped object plus all keyword
    arguments to :func:`xrspatial.utils.rechunk_no_shuffle`.
    """
    from .utils import rechunk_no_shuffle as _impl
    return _impl(self._obj, **kwargs)

xrspatial/utils.py

Lines changed: 42 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1028,8 +1028,34 @@ def _sample_windows_min_max(
10281028
return float(np.nanmin(np.array(mins, dtype=float))), float(np.nanmax(np.array(maxs, dtype=float)))
10291029

10301030

1031+
def _rechunk_dataarray(agg, target_bytes):
    """Rechunk a single dask-backed DataArray toward ``target_bytes`` per chunk.

    Grows chunks by an integer multiple of the original chunk size so the
    rechunk is a pure merge of neighbouring chunks (no shuffle). Returns
    *agg* unchanged when it is not dask-backed, already at or above the
    target, or when no integer multiplier greater than 1 fits.

    Parameters
    ----------
    agg : xr.DataArray
        Array to rechunk. Non-dask data passes through unchanged.
    target_bytes : int or float
        Desired upper bound, in bytes, on the size of each new chunk.

    Returns
    -------
    xr.DataArray
        The input, possibly rechunked to larger, exact-multiple chunks.
    """
    # Only dask-backed arrays have chunks to merge.
    if not has_dask_array() or not isinstance(agg.data, da.Array):
        return agg

    chunks = agg.chunks  # tuple of per-dimension chunk-size tuples
    base = tuple(c[0] for c in chunks)  # leading chunk length per dim

    # Bytes in one (leading) chunk: itemsize * product of chunk lengths.
    current_bytes = agg.dtype.itemsize
    for b in base:
        current_bytes *= b

    # Fix: guard 0-d arrays and degenerate (zero-length) chunks. Without
    # this, ``ratio ** (1.0 / ndim)`` divides by zero for a scalar array
    # (ndim == 0) and ``target_bytes / current_bytes`` blows up when a
    # chunk length is 0. Nothing can be merged in either case.
    if len(base) == 0 or current_bytes <= 0:
        return agg

    if current_bytes >= target_bytes:
        return agg

    # Spread the growth factor evenly across dimensions, flooring so the
    # resulting chunk never exceeds the target.
    ndim = len(base)
    ratio = target_bytes / current_bytes
    multiplier = max(1, int(ratio ** (1.0 / ndim)))

    if multiplier <= 1:
        return agg

    # New chunks are exact multiples of the originals, so dask merges
    # neighbouring chunks in place instead of shuffling data.
    new_chunks = {dim: b * multiplier for dim, b in zip(agg.dims, base)}
    return agg.chunk(new_chunks)
1055+
1056+
10311057
def rechunk_no_shuffle(agg, target_mb=128):
    """Rechunk dask-backed data without triggering a shuffle.

    Each new chunk is an exact integer multiple of the original chunk
    size, which lets dask merge neighbouring chunks rather than move
    data between them.

    Parameters
    ----------
    agg : xr.DataArray or xr.Dataset
        Input raster or collection of rasters. Non-dask variables
        pass through unchanged.
    target_mb : int or float
        Target chunk size in megabytes. The actual chunk size will be
        the closest multiple of the source chunk that does not exceed
        this target. Default 128.

    Returns
    -------
    xr.DataArray or xr.Dataset
        Rechunked object. Coordinates and attributes are preserved.

    Raises
    ------
    TypeError
        If *agg* is not an ``xr.DataArray`` or ``xr.Dataset``.
    ValueError
        If *target_mb* is not positive.

    Examples
    --------
    >>> big = rechunk_no_shuffle(arr, target_mb=64)
    >>> big.chunks  # exact multiples of the source chunk size
    """
    # Validate inputs up front, raising the exceptions callers expect.
    if not isinstance(agg, (xr.DataArray, xr.Dataset)):
        raise TypeError(
            f"rechunk_no_shuffle(): expected xr.DataArray or xr.Dataset, "
            f"got {type(agg).__name__}"
        )
    if target_mb <= 0:
        raise ValueError(
            f"rechunk_no_shuffle(): target_mb must be > 0, got {target_mb}"
        )

    target_bytes = target_mb * 1024 * 1024

    # Dataset: rechunk each data variable independently; assign() keeps
    # coordinates and dataset-level attributes intact.
    if isinstance(agg, xr.Dataset):
        rechunked = {
            name: _rechunk_dataarray(var, target_bytes)
            for name, var in agg.data_vars.items()
        }
        return agg.assign(rechunked)

    # DataArray: delegate straight to the per-array helper.
    return _rechunk_dataarray(agg, target_bytes)

0 commit comments

Comments
 (0)