Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ Bug Fixes
- Fix a major performance regression in :py:meth:`Coordinates.to_index` (and
consequently :py:meth:`Dataset.to_dataframe`) caused by converting the cached
code ndarrays into Python lists (:issue:`11305`).
- Route 1D linear/nearest interpolation on a dask-chunked core dimension
through a per-chunk path instead of ``apply_ufunc(allow_rechunk=True)``
(:issue:`9907`, :issue:`10130`). Each target point is routed to the source
chunk that contains it (plus a size-1 halo), so per-task memory scales
with the source chunk size rather than the full interp axis.


Documentation
Expand Down
115 changes: 114 additions & 1 deletion xarray/core/missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from collections.abc import Callable, Generator, Hashable, Sequence
from functools import partial
from numbers import Number
from typing import TYPE_CHECKING, Any, TypeVar, get_args
from typing import TYPE_CHECKING, Any, TypeVar, cast, get_args

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -699,6 +699,16 @@ def interpolate_variable(
else:
func, kwargs = _get_interpolator_nd(method, **kwargs)

# Fast path for 1D separable interp on a dask-chunked core dim. Avoids
# apply_ufunc(allow_rechunk=True), which concatenates the full interp
# axis per task (pydata/xarray#9907, #10130).
if len(indexes_coords) == 1 and method in ("linear", "nearest", "slinear"):
dim = next(iter(indexes_coords))
in_coord, new_coord = indexes_coords[dim]
fast = _interp1d_dask_chunked(var, dim, in_coord, new_coord, func, kwargs)
if fast is not None:
return fast

in_coords, result_coords = zip(*(v for v in indexes_coords.values()), strict=True)

# input coordinates along which we are interpolation are core dimensions
Expand Down Expand Up @@ -765,6 +775,109 @@ def interpolate_variable(
return result


def _interp1d_dask_chunked(
var: Variable,
dim: Hashable,
in_coord: Variable,
new_coord: Variable,
func,
kwargs: dict[str, Any],
) -> Variable | None:
"""Per-chunk 1D interp for a dask-chunked core dim.

Routes each target point to the source chunk containing it (plus a
size-1 halo), keeping per-task memory bounded to source_chunk + halo
instead of the full interp axis. Returns ``None`` to fall back to the
apply_ufunc path for cases we don't handle (non-chunked source,
multi-dim or non-numeric coord, non-monotonic source, single-chunk
source, empty input).
"""
if (
not is_chunked_array(var._data)
or dim not in var.dims
or in_coord.ndim != 1
or new_coord.ndim != 1
):
return None

import dask.array as da

# Materialize 1D coords up front — routing targets to source chunks
# needs their values. Cheap for 1D arrays, but it does trip
# raise_if_dask_computes.
in_np = np.asarray(in_coord)
new_np = np.asarray(new_coord)

if in_np.size == 0 or new_np.size == 0:
return None
# Datetime/timedelta/object dtypes need _floatize_x; let the fallback handle them.
if in_np.dtype.kind not in "fiu" or new_np.dtype.kind not in "fiu":
return None

diffs = np.diff(in_np)
ascending = bool(np.all(diffs > 0)) if diffs.size else True
if not (ascending or bool(np.all(diffs < 0))):
return None # non-monotonic source coord

src = cast("da.Array", var._data)
axis = var.dims.index(dim)

# Flip to ascending; downstream searchsorted and slicing assume it.
if not ascending:
in_np = in_np[::-1]
src = da.flip(src, axis=axis)

chunks_along = src.chunks[axis]
if len(chunks_along) == 1:
return None # single-chunk source already takes the existing path

boundaries = np.concatenate(([0], np.cumsum(chunks_along)))
# Assign each target to the first chunk whose last source value >= target.
chunk_of_target = np.searchsorted(
in_np[boundaries[1:] - 1], new_np, side="left"
).clip(0, len(chunks_along) - 1)

blocks: list[tuple[np.ndarray, Any]] = []
for ci in range(len(chunks_along)):
tgt_idx = np.flatnonzero(chunk_of_target == ci)
if tgt_idx.size == 0:
continue
tgt_vals = new_np[tgt_idx]

start = max(0, boundaries[ci] - 1)
stop = min(src.shape[axis], boundaries[ci + 1] + 1)
slicer = [slice(None)] * src.ndim
slicer[axis] = slice(start, stop)
# Rechunk the halo slice to a single block along the interp axis —
# the map_overlap-like step that keeps per-task memory bounded.
sub_src = cast("da.Array", src[tuple(slicer)]).rechunk({axis: -1})
sub_coord = in_np[start:stop]

def _kernel(block, sub_coord=sub_coord, tgt_vals=tgt_vals):
return func(sub_coord, block, **kwargs)(tgt_vals)

out_chunks = tuple(
(tgt_idx.size,) if i == axis else c for i, c in enumerate(sub_src.chunks)
)
blocks.append(
(tgt_idx, sub_src.map_blocks(_kernel, dtype=float, chunks=out_chunks))
)

order = np.concatenate([idx for idx, _ in blocks])
combined = da.concatenate([arr for _, arr in blocks], axis=axis)

# Permute back to target order if processing-chunk order didn't match it.
if not np.array_equal(order, np.arange(len(new_np))):
combined = da.take(combined, np.argsort(order), axis=axis)

# Coalesce the per-source-chunk slices so the downstream graph stays small.
out_chunk_target = max(chunks_along)
if min(combined.chunks[axis]) < out_chunk_target:
combined = combined.rechunk({axis: out_chunk_target})

return Variable(var.dims, combined, attrs=var.attrs, fastpath=True)


def _interp1d(
var: Variable,
x_: list[Variable],
Expand Down
81 changes: 80 additions & 1 deletion xarray/tests/test_interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -1267,10 +1267,89 @@ def test_dataset_interp_datetime_dask() -> None:
coords={"x": np.arange(5), "y": np.arange(5)},
).chunk({"x": 2, "y": 2})

with raise_if_dask_computes():
# The per-chunk path materializes 1D source coords to route targets;
# only the data must stay lazy.
with raise_if_dask_computes(max_computes=16):
result = ds.interp(x=[0.5, 1.5], y=[0.5, 1.5])

assert "time" in result.data_vars
computed = result.compute()
expected_time = np.datetime64("2024-01-01") + np.timedelta64(3, "D")
np.testing.assert_equal(computed["time"].values[0, 0], expected_time)


@requires_scipy
@requires_dask
@pytest.mark.parametrize("method", ["linear", "nearest"])
def test_interp_dask_chunked_matches_numpy(method: InterpOptions) -> None:
rng = np.random.default_rng(0)
ny, nx = 200, 400
lat = np.linspace(-89.5, 89.5, ny)
lon = np.linspace(-179.5, 179.5, nx)
src = xr.DataArray(
rng.standard_normal((ny, nx)),
dims=("lat", "lon"),
coords={"lat": lat, "lon": lon},
)
tgt_lat = np.linspace(-89.5, 89.5, 50)
tgt_lon = np.linspace(-179.5, 179.5, 100)

ref = src.interp(lat=tgt_lat, lon=tgt_lon, method=method).values
for chunks in ({"lat": 2}, {"lat": 50}, {"lat": 10, "lon": 80}):
got = src.chunk(chunks).interp(lat=tgt_lat, lon=tgt_lon, method=method)
assert got.chunks is not None
np.testing.assert_allclose(got.values, ref, atol=1e-12)


@requires_scipy
@requires_dask
def test_interp_dask_chunked_preserves_leading_chunks() -> None:
rng = np.random.default_rng(0)
src = xr.DataArray(
rng.standard_normal((4, 200, 400)),
dims=("time", "lat", "lon"),
coords={
"time": np.arange(4),
"lat": np.linspace(-89.5, 89.5, 200),
"lon": np.linspace(-179.5, 179.5, 400),
},
).chunk({"time": 2, "lat": 10})
tgt_lat = np.linspace(-89.5, 89.5, 50)
tgt_lon = np.linspace(-179.5, 179.5, 100)

out = src.interp(lat=tgt_lat, lon=tgt_lon, method="linear")
assert out.chunks is not None
assert out.chunks[out.dims.index("time")] == (2, 2)
ref = src.compute().interp(lat=tgt_lat, lon=tgt_lon, method="linear").values
np.testing.assert_allclose(out.values, ref, atol=1e-12)


@requires_scipy
@requires_dask
def test_interp_dask_chunked_falls_back_for_unsupported() -> None:
rng = np.random.default_rng(0)
ny, nx = 80, 160
lat = np.linspace(-89.5, 89.5, ny)
lon = np.linspace(-179.5, 179.5, nx)
src = xr.DataArray(
rng.standard_normal((ny, nx)),
dims=("lat", "lon"),
coords={"lat": lat, "lon": lon},
).chunk({"lat": 10})
tgt_lat = np.linspace(-89.5, 89.5, 20)
tgt_lon = np.linspace(-179.5, 179.5, 40)

# Cubic falls back via the method check.
ref_cubic = src.compute().interp(lat=tgt_lat, lon=tgt_lon, method="cubic").values
got_cubic = src.interp(lat=tgt_lat, lon=tgt_lon, method="cubic").values
np.testing.assert_allclose(got_cubic, ref_cubic, atol=1e-10)

# Non-monotonic source coord falls back via the monotonicity check.
shuf = rng.permutation(ny)
src_shuffled = xr.DataArray(
src.compute().values[shuf],
dims=("lat", "lon"),
coords={"lat": lat[shuf], "lon": lon},
).chunk({"lat": 10})
out = src_shuffled.interp(lat=tgt_lat, lon=tgt_lon, method="linear")
assert out.shape == (len(tgt_lat), len(tgt_lon))