Skip to content

Commit a71302a

Browse files
committed
Remove zarr re-open logic from rechunk and preview
Drop _is_unmodified_zarr, _reopen_preview_chunks, and _preview_chunk_budget. Dataset rechunk now uses ds.chunk() instead of re-opening the zarr store under the hood.
1 parent 4c665a1 commit a71302a

File tree

3 files changed

+3
-190
lines changed

3 files changed

+3
-190
lines changed

xrspatial/preview.py

Lines changed: 0 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -13,33 +13,6 @@
1313
_COARSEN_METHODS = ('mean', 'median', 'max', 'min')
1414
_METHODS = (*_COARSEN_METHODS, 'nearest', 'bilinear')
1515

16-
# Fallback chunk budget when no distributed client is available.
17-
_DEFAULT_PREVIEW_CHUNK_BYTES = 512 * 1024 * 1024
18-
19-
20-
def _preview_chunk_budget():
21-
"""Max bytes per preview chunk, based on the active dask cluster.
22-
23-
If a ``dask.distributed`` client is connected, returns
24-
``worker_memory * 0.7 / nthreads`` so that concurrent tasks on
25-
the same worker stay under the memory-pause threshold. Otherwise
26-
falls back to ``_DEFAULT_PREVIEW_CHUNK_BYTES`` (512 MB).
27-
"""
28-
try:
29-
from dask.distributed import get_client
30-
client = get_client()
31-
info = client.scheduler_info()
32-
workers = info.get('workers', {})
33-
if workers:
34-
w = next(iter(workers.values()))
35-
mem = w.get('memory_limit', 0)
36-
nthreads = w.get('nthreads', 1) or 1
37-
if mem > 0:
38-
return int(mem * 0.7 / nthreads)
39-
except Exception:
40-
pass
41-
return _DEFAULT_PREVIEW_CHUNK_BYTES
42-
4316

4417
def _nan_full(oh, ow, block):
4518
"""NaN-filled ``(oh, ow)`` array matching *block*'s type and dtype."""
@@ -339,58 +312,6 @@ def _refine_to_target(result, target_h, target_w, y_dim, x_dim):
339312
# Public API
340313
# ---------------------------------------------------------------------------
341314

342-
def _reopen_preview_chunks(agg):
343-
"""Re-open a zarr-backed DataArray with memory-safe chunks.
344-
345-
Computes the largest chunk size that is an exact multiple of the
346-
zarr storage chunks and fits under the per-task memory budget
347-
(derived from the active dask cluster configuration). This keeps
348-
the task graph small (far fewer chunks than storage granularity)
349-
while keeping peak memory per task well within worker limits even
350-
when ``threads_per_worker > 1``.
351-
352-
Returns a new DataArray or *None* if the source isn't available.
353-
When the input is a spatial subset (``.sel()``), the returned
354-
array covers the same coordinate range.
355-
"""
356-
source = agg.encoding.get('_xrs_zarr_source')
357-
pref = agg.encoding.get('preferred_chunks')
358-
if source is None or pref is None or agg.name is None:
359-
return None
360-
try:
361-
budget = _preview_chunk_budget()
362-
# Compute the largest multiple of storage chunks that fits
363-
# under the per-task budget.
364-
base = tuple(pref[d] for d in agg.dims if d in pref)
365-
if not base or len(base) != 2:
366-
return None
367-
base_bytes = agg.dtype.itemsize * base[0] * base[1]
368-
if base_bytes >= budget:
369-
# Storage chunks already exceed the budget; use them as-is.
370-
chunks = pref
371-
else:
372-
ratio = budget / base_bytes
373-
multiplier = max(1, int(ratio ** (1.0 / len(base))))
374-
chunks = {d: pref[d] * multiplier for d in agg.dims if d in pref}
375-
376-
ds = xr.open_zarr(source, chunks=chunks)
377-
if agg.name not in ds:
378-
return None
379-
da_full = ds[agg.name]
380-
# Select to match the current DataArray's coordinate extent.
381-
sel = {}
382-
for dim in agg.dims:
383-
if dim in agg.coords and dim in da_full.coords:
384-
c = agg.coords[dim].values
385-
if len(c) > 0:
386-
sel[dim] = slice(c[0], c[-1])
387-
if sel:
388-
da_full = da_full.sel(sel)
389-
return da_full
390-
except Exception:
391-
return None
392-
393-
394315
@supports_dataset
395316
def preview(agg, width=1000, height=None, method='mean', name='preview'):
396317
"""Downsample a raster to target pixel dimensions.
@@ -433,22 +354,6 @@ def preview(agg, width=1000, height=None, method='mean', name='preview'):
433354
f"method must be one of {_METHODS!r}, got {method!r}"
434355
)
435356

436-
# If chunks are too large for a single worker task, re-open from
437-
# the zarr source with memory-safe chunks. The budget accounts
438-
# for threads_per_worker so concurrent tasks don't collectively
439-
# exceed the worker's memory-pause threshold.
440-
try:
441-
import dask.array as _da
442-
if isinstance(agg.data, _da.Array):
443-
chunk_bytes = (agg.dtype.itemsize
444-
* agg.data.chunksize[0] * agg.data.chunksize[1])
445-
if chunk_bytes > _preview_chunk_budget():
446-
safe = _reopen_preview_chunks(agg)
447-
if safe is not None:
448-
agg = safe
449-
except ImportError:
450-
pass
451-
452357
h = agg.sizes[agg.dims[0]]
453358
w = agg.sizes[agg.dims[1]]
454359

xrspatial/tests/test_rechunk_no_shuffle.py

Lines changed: 1 addition & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,6 @@
88

99
da = pytest.importorskip("dask.array")
1010

11-
_has_zarr = True
12-
try:
13-
import zarr # noqa: F401
14-
except ImportError:
15-
_has_zarr = False
16-
17-
requires_zarr = pytest.mark.skipif(not _has_zarr, reason="zarr not installed")
18-
1911

2012
# ---------------------------------------------------------------------------
2113
# Helpers
@@ -110,52 +102,7 @@ def test_rejects_non_dataarray():
110102
rechunk_no_shuffle(np.zeros((10, 10)))
111103

112104

113-
# ---------------------------------------------------------------------------
114-
# Zarr re-open optimisation
115-
# ---------------------------------------------------------------------------
116-
117-
@requires_zarr
118-
def test_zarr_reopen_reduces_graph(tmp_path):
119-
"""For a fresh zarr Dataset, rechunk should re-open with fewer tasks."""
120-
path = str(tmp_path / "rns_zarr_reopen.zarr")
121-
ds = xr.Dataset({"elev": xr.DataArray(
122-
np.random.rand(100, 100).astype(np.float64), dims=["y", "x"],
123-
coords={"y": np.arange(100), "x": np.arange(100)},
124-
)})
125-
ds.chunk({"y": 10, "x": 10}).to_zarr(path)
126-
127-
ds_in = xr.open_zarr(path)
128-
tasks_before = len(ds_in["elev"].data.__dask_graph__())
129-
130-
ds_out = rechunk_no_shuffle(ds_in, target_mb=1)
131-
tasks_after = len(ds_out["elev"].data.__dask_graph__())
132-
133-
# Re-open should produce fewer tasks, not more
134-
assert tasks_after < tasks_before, (
135-
f"expected fewer tasks after rechunk, got {tasks_after} >= {tasks_before}"
136-
)
137-
# Values must match
138-
np.testing.assert_array_equal(ds_in["elev"].values, ds_out["elev"].values)
139-
140-
141-
@requires_zarr
142-
def test_zarr_reopen_skipped_after_sel(tmp_path):
143-
"""After .sel(), the graph has >2 layers so re-open is skipped."""
144-
path = str(tmp_path / "rns_zarr_sel.zarr")
145-
ds = xr.Dataset({"elev": xr.DataArray(
146-
np.random.rand(100, 100).astype(np.float64), dims=["y", "x"],
147-
coords={"y": np.arange(100), "x": np.arange(100)},
148-
)})
149-
ds.chunk({"y": 10, "x": 10}).to_zarr(path)
150-
151-
ds_sel = xr.open_zarr(path).sel(y=slice(10, 50))
152-
result = rechunk_no_shuffle(ds_sel, target_mb=1)
153-
154-
# Should still rechunk (values match), just not via re-open
155-
np.testing.assert_array_equal(ds_sel["elev"].values, result["elev"].values)
156-
157-
158-
def test_dataset_rechunk_fallback():
105+
def test_dataset_rechunk():
159106
    """Dataset rechunks every variable via ds.chunk() (no zarr re-open)."""
160107
ds = xr.Dataset({
161108
"elev": xr.DataArray(

xrspatial/utils.py

Lines changed: 2 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1055,19 +1055,6 @@ def _no_shuffle_chunks(chunks, dtype, dims, target_mb):
10551055
return {dim: b * multiplier for dim, b in zip(dims, base)}
10561056

10571057

1058-
def _is_unmodified_zarr(ds):
1059-
"""True when every dask variable is a direct zarr read (2 layers)."""
1060-
found_dask = False
1061-
for var in ds.data_vars.values():
1062-
data = var.data
1063-
if has_dask_array() and isinstance(data, da.Array):
1064-
found_dask = True
1065-
graph = data.__dask_graph__()
1066-
if hasattr(graph, 'layers') and len(graph.layers) != 2:
1067-
return False
1068-
return found_dask
1069-
1070-
10711058
def rechunk_no_shuffle(agg, target_mb=128):
10721059
"""Rechunk a dask-backed DataArray or Dataset without triggering a shuffle.
10731060
@@ -1076,12 +1063,6 @@ def rechunk_no_shuffle(agg, target_mb=128):
10761063
merge whole source chunks in-place instead of splitting and
10771064
recombining partial blocks (which is effectively a shuffle).
10781065
1079-
For file-backed data (e.g. Zarr stores), the function re-opens
1080-
the source with the target chunk sizes so that each dask task
1081-
reads multiple storage chunks in one call. This produces a
1082-
dramatically smaller task graph compared to ``.chunk()``, which
1083-
adds a rechunk merge layer on top of the existing read tasks.
1084-
10851066
Parameters
10861067
----------
10871068
agg : xr.DataArray or xr.Dataset
@@ -1137,7 +1118,7 @@ def rechunk_no_shuffle(agg, target_mb=128):
11371118

11381119

11391120
def _rechunk_dataset_no_shuffle(ds, target_mb):
1140-
"""Rechunk a Dataset, re-opening from zarr when possible."""
1121+
"""Rechunk every variable in a Dataset without triggering a shuffle."""
11411122
if target_mb <= 0:
11421123
raise ValueError(
11431124
f"rechunk_no_shuffle(): target_mb must be > 0, got {target_mb}"
@@ -1160,27 +1141,7 @@ def _rechunk_dataset_no_shuffle(ds, target_mb):
11601141
if new_chunks is None:
11611142
return ds
11621143

1163-
# For unmodified zarr reads, re-open with target chunks so
1164-
# each dask task reads multiple storage chunks in one call.
1165-
# This avoids the extra rechunk-merge graph layer that
1166-
# .chunk() would add on top of the existing read tasks.
1167-
source = ds.encoding.get('source')
1168-
if source is not None and _is_unmodified_zarr(ds):
1169-
try:
1170-
reopened = xr.open_zarr(source, chunks=new_chunks)
1171-
if set(ds.data_vars) <= set(reopened.data_vars):
1172-
result = reopened[list(ds.data_vars)]
1173-
# Propagate zarr source into each variable's encoding
1174-
# so downstream operations (e.g. preview) can re-open
1175-
# with different chunks when needed.
1176-
for name in result.data_vars:
1177-
result[name].encoding['_xrs_zarr_source'] = source
1178-
return result
1179-
except Exception:
1180-
pass
1181-
1182-
# Fallback: rechunk each variable individually.
1183-
return ds.map(rechunk_no_shuffle, target_mb=target_mb)
1144+
return ds.chunk(new_chunks)
11841145

11851146

11861147
def _normalize_depth(depth, ndim):

0 commit comments

Comments (0)