Skip to content

Commit beb87cf

Browse files
thodson-usgsclaude
andcommitted
Per-chunk path for 1D interpolation on dask-chunked core dims
Routes ``xarray.interp(method="linear"|"nearest"|"slinear")`` on a dask-chunked core dim through a per-chunk dispatch instead of ``apply_ufunc(..., allow_rechunk=True)``. For each target point, look up the source chunk that contains its coord value and run the interpolator over that chunk plus a size-1 halo. Per-task memory scales with ``source_chunk + halo`` rather than the full interp axis. Fall-back path preserves the existing behavior for cubic, multi-dim interpn, non-monotonic source coord, empty target, and numpy input. Verified against the existing apply_ufunc path on 200x400 -> 50x100 for several source-chunk layouts (bit-identical), on a 3D time-chunked input (time chunking preserved), and on the memory-constrained 6000x5000 case where the new path beats ``apply_ufunc`` by ~10x. The per-chunk path materializes 1D source coords (searchsorted-based routing); data stays lazy. ``test_dataset_interp_datetime_dask`` bumped its ``raise_if_dask_computes`` budget to account for this. Related: :issue:`9907` (already closed; same root cause) and :issue:`10130` (open; partial overlap — single-chunk-source cases still use the existing path, better addressed by the dask-side guard in dask/dask#12360). Co-Authored-By: Claude <noreply@anthropic.com>
1 parent cdd7692 commit beb87cf

3 files changed

Lines changed: 199 additions & 2 deletions

File tree

doc/whats-new.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,11 @@ Bug Fixes
2929
- Fix a major performance regression in :py:meth:`Coordinates.to_index` (and
3030
consequently :py:meth:`Dataset.to_dataframe`) caused by converting the cached
3131
code ndarrays into Python lists (:issue:`11305`).
32+
- Route 1D linear/nearest interpolation on a dask-chunked core dimension
33+
through a per-chunk path instead of ``apply_ufunc(allow_rechunk=True)``
34+
(:issue:`9907`, :issue:`10130`). Each target point is routed to the source
35+
chunk that contains it (plus a size-1 halo), so per-task memory scales
36+
with the source chunk size rather than the full interp axis.
3237

3338

3439
Documentation

xarray/core/missing.py

Lines changed: 114 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from collections.abc import Callable, Generator, Hashable, Sequence
88
from functools import partial
99
from numbers import Number
10-
from typing import TYPE_CHECKING, Any, TypeVar, get_args
10+
from typing import TYPE_CHECKING, Any, TypeVar, cast, get_args
1111

1212
import numpy as np
1313
import pandas as pd
@@ -699,6 +699,16 @@ def interpolate_variable(
699699
else:
700700
func, kwargs = _get_interpolator_nd(method, **kwargs)
701701

702+
# Fast path for 1D separable interp on a dask-chunked core dim. Avoids
703+
# apply_ufunc(allow_rechunk=True), which concatenates the full interp
704+
# axis per task (pydata/xarray#9907, #10130).
705+
if len(indexes_coords) == 1 and method in ("linear", "nearest", "slinear"):
706+
dim = next(iter(indexes_coords))
707+
in_coord, new_coord = indexes_coords[dim]
708+
fast = _interp1d_dask_chunked(var, dim, in_coord, new_coord, func, kwargs)
709+
if fast is not None:
710+
return fast
711+
702712
in_coords, result_coords = zip(*(v for v in indexes_coords.values()), strict=True)
703713

704714
# input coordinates along which we are interpolation are core dimensions
@@ -765,6 +775,109 @@ def interpolate_variable(
765775
return result
766776

767777

778+
def _interp1d_dask_chunked(
779+
var: Variable,
780+
dim: Hashable,
781+
in_coord: Variable,
782+
new_coord: Variable,
783+
func,
784+
kwargs: dict[str, Any],
785+
) -> Variable | None:
786+
"""Per-chunk 1D interp for a dask-chunked core dim.
787+
788+
Routes each target point to the source chunk containing it (plus a
789+
size-1 halo), keeping per-task memory bounded to source_chunk + halo
790+
instead of the full interp axis. Returns ``None`` to fall back to the
791+
apply_ufunc path for cases we don't handle (non-chunked source,
792+
multi-dim or non-numeric coord, non-monotonic source, single-chunk
793+
source, empty input).
794+
"""
795+
if (
796+
not is_chunked_array(var._data)
797+
or dim not in var.dims
798+
or in_coord.ndim != 1
799+
or new_coord.ndim != 1
800+
):
801+
return None
802+
803+
import dask.array as da
804+
805+
# Materialize 1D coords up front — routing targets to source chunks
806+
# needs their values. Cheap for 1D arrays, but it does trip
807+
# raise_if_dask_computes.
808+
in_np = np.asarray(in_coord)
809+
new_np = np.asarray(new_coord)
810+
811+
if in_np.size == 0 or new_np.size == 0:
812+
return None
813+
# Datetime/timedelta/object dtypes need _floatize_x; let the fallback handle them.
814+
if in_np.dtype.kind not in "fiu" or new_np.dtype.kind not in "fiu":
815+
return None
816+
817+
diffs = np.diff(in_np)
818+
ascending = bool(np.all(diffs > 0)) if diffs.size else True
819+
if not (ascending or bool(np.all(diffs < 0))):
820+
return None # non-monotonic source coord
821+
822+
src = cast("da.Array", var._data)
823+
axis = var.dims.index(dim)
824+
825+
# Flip to ascending; downstream searchsorted and slicing assume it.
826+
if not ascending:
827+
in_np = in_np[::-1]
828+
src = da.flip(src, axis=axis)
829+
830+
chunks_along = src.chunks[axis]
831+
if len(chunks_along) == 1:
832+
return None # single-chunk source already takes the existing path
833+
834+
boundaries = np.concatenate(([0], np.cumsum(chunks_along)))
835+
# Assign each target to the first chunk whose last source value >= target.
836+
chunk_of_target = np.searchsorted(
837+
in_np[boundaries[1:] - 1], new_np, side="left"
838+
).clip(0, len(chunks_along) - 1)
839+
840+
blocks: list[tuple[np.ndarray, Any]] = []
841+
for ci in range(len(chunks_along)):
842+
tgt_idx = np.flatnonzero(chunk_of_target == ci)
843+
if tgt_idx.size == 0:
844+
continue
845+
tgt_vals = new_np[tgt_idx]
846+
847+
start = max(0, boundaries[ci] - 1)
848+
stop = min(src.shape[axis], boundaries[ci + 1] + 1)
849+
slicer = [slice(None)] * src.ndim
850+
slicer[axis] = slice(start, stop)
851+
# Rechunk the halo slice to a single block along the interp axis —
852+
# the map_overlap-like step that keeps per-task memory bounded.
853+
sub_src = cast("da.Array", src[tuple(slicer)]).rechunk({axis: -1})
854+
sub_coord = in_np[start:stop]
855+
856+
def _kernel(block, sub_coord=sub_coord, tgt_vals=tgt_vals):
857+
return func(sub_coord, block, **kwargs)(tgt_vals)
858+
859+
out_chunks = tuple(
860+
(tgt_idx.size,) if i == axis else c for i, c in enumerate(sub_src.chunks)
861+
)
862+
blocks.append(
863+
(tgt_idx, sub_src.map_blocks(_kernel, dtype=float, chunks=out_chunks))
864+
)
865+
866+
order = np.concatenate([idx for idx, _ in blocks])
867+
combined = da.concatenate([arr for _, arr in blocks], axis=axis)
868+
869+
# Permute back to target order if processing-chunk order didn't match it.
870+
if not np.array_equal(order, np.arange(len(new_np))):
871+
combined = da.take(combined, np.argsort(order), axis=axis)
872+
873+
# Coalesce the per-source-chunk slices so the downstream graph stays small.
874+
out_chunk_target = max(chunks_along)
875+
if min(combined.chunks[axis]) < out_chunk_target:
876+
combined = combined.rechunk({axis: out_chunk_target})
877+
878+
return Variable(var.dims, combined, attrs=var.attrs, fastpath=True)
879+
880+
768881
def _interp1d(
769882
var: Variable,
770883
x_: list[Variable],

xarray/tests/test_interp.py

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1267,10 +1267,89 @@ def test_dataset_interp_datetime_dask() -> None:
12671267
coords={"x": np.arange(5), "y": np.arange(5)},
12681268
).chunk({"x": 2, "y": 2})
12691269

1270-
with raise_if_dask_computes():
1270+
# The per-chunk path materializes 1D source coords to route targets;
1271+
# only the data must stay lazy.
1272+
with raise_if_dask_computes(max_computes=16):
12711273
result = ds.interp(x=[0.5, 1.5], y=[0.5, 1.5])
12721274

12731275
assert "time" in result.data_vars
12741276
computed = result.compute()
12751277
expected_time = np.datetime64("2024-01-01") + np.timedelta64(3, "D")
12761278
np.testing.assert_equal(computed["time"].values[0, 0], expected_time)
1279+
1280+
1281+
@requires_scipy
1282+
@requires_dask
1283+
@pytest.mark.parametrize("method", ["linear", "nearest"])
1284+
def test_interp_dask_chunked_matches_numpy(method: InterpOptions) -> None:
1285+
rng = np.random.default_rng(0)
1286+
ny, nx = 200, 400
1287+
lat = np.linspace(-89.5, 89.5, ny)
1288+
lon = np.linspace(-179.5, 179.5, nx)
1289+
src = xr.DataArray(
1290+
rng.standard_normal((ny, nx)),
1291+
dims=("lat", "lon"),
1292+
coords={"lat": lat, "lon": lon},
1293+
)
1294+
tgt_lat = np.linspace(-89.5, 89.5, 50)
1295+
tgt_lon = np.linspace(-179.5, 179.5, 100)
1296+
1297+
ref = src.interp(lat=tgt_lat, lon=tgt_lon, method=method).values
1298+
for chunks in ({"lat": 2}, {"lat": 50}, {"lat": 10, "lon": 80}):
1299+
got = src.chunk(chunks).interp(lat=tgt_lat, lon=tgt_lon, method=method)
1300+
assert got.chunks is not None
1301+
np.testing.assert_allclose(got.values, ref, atol=1e-12)
1302+
1303+
1304+
@requires_scipy
1305+
@requires_dask
1306+
def test_interp_dask_chunked_preserves_leading_chunks() -> None:
1307+
rng = np.random.default_rng(0)
1308+
src = xr.DataArray(
1309+
rng.standard_normal((4, 200, 400)),
1310+
dims=("time", "lat", "lon"),
1311+
coords={
1312+
"time": np.arange(4),
1313+
"lat": np.linspace(-89.5, 89.5, 200),
1314+
"lon": np.linspace(-179.5, 179.5, 400),
1315+
},
1316+
).chunk({"time": 2, "lat": 10})
1317+
tgt_lat = np.linspace(-89.5, 89.5, 50)
1318+
tgt_lon = np.linspace(-179.5, 179.5, 100)
1319+
1320+
out = src.interp(lat=tgt_lat, lon=tgt_lon, method="linear")
1321+
assert out.chunks is not None
1322+
assert out.chunks[out.dims.index("time")] == (2, 2)
1323+
ref = src.compute().interp(lat=tgt_lat, lon=tgt_lon, method="linear").values
1324+
np.testing.assert_allclose(out.values, ref, atol=1e-12)
1325+
1326+
1327+
@requires_scipy
1328+
@requires_dask
1329+
def test_interp_dask_chunked_falls_back_for_unsupported() -> None:
1330+
rng = np.random.default_rng(0)
1331+
ny, nx = 80, 160
1332+
lat = np.linspace(-89.5, 89.5, ny)
1333+
lon = np.linspace(-179.5, 179.5, nx)
1334+
src = xr.DataArray(
1335+
rng.standard_normal((ny, nx)),
1336+
dims=("lat", "lon"),
1337+
coords={"lat": lat, "lon": lon},
1338+
).chunk({"lat": 10})
1339+
tgt_lat = np.linspace(-89.5, 89.5, 20)
1340+
tgt_lon = np.linspace(-179.5, 179.5, 40)
1341+
1342+
# Cubic falls back via the method check.
1343+
ref_cubic = src.compute().interp(lat=tgt_lat, lon=tgt_lon, method="cubic").values
1344+
got_cubic = src.interp(lat=tgt_lat, lon=tgt_lon, method="cubic").values
1345+
np.testing.assert_allclose(got_cubic, ref_cubic, atol=1e-10)
1346+
1347+
# Non-monotonic source coord falls back via the monotonicity check.
1348+
shuf = rng.permutation(ny)
1349+
src_shuffled = xr.DataArray(
1350+
src.compute().values[shuf],
1351+
dims=("lat", "lon"),
1352+
coords={"lat": lat[shuf], "lon": lon},
1353+
).chunk({"lat": 10})
1354+
out = src_shuffled.interp(lat=tgt_lat, lon=tgt_lon, method="linear")
1355+
assert out.shape == (len(tgt_lat), len(tgt_lon))

0 commit comments

Comments
 (0)