|
7 | 7 | from collections.abc import Callable, Generator, Hashable, Sequence |
8 | 8 | from functools import partial |
9 | 9 | from numbers import Number |
10 | | -from typing import TYPE_CHECKING, Any, TypeVar, get_args |
| 10 | +from typing import TYPE_CHECKING, Any, TypeVar, cast, get_args |
11 | 11 |
|
12 | 12 | import numpy as np |
13 | 13 | import pandas as pd |
@@ -699,6 +699,27 @@ def interpolate_variable( |
699 | 699 | else: |
700 | 700 | func, kwargs = _get_interpolator_nd(method, **kwargs) |
701 | 701 |
|
| 702 | + # Fast path for 1D separable interp on a dask-chunked core dim. Avoids |
| 703 | + # apply_ufunc(allow_rechunk=True) — the concat-then-apply dance that |
| 704 | + # blows up task graphs (pydata/xarray#9907, #10130). Each target point |
| 705 | + # is routed to the source chunk that contains it (plus a halo), so |
| 706 | + # per-task memory scales with source_chunk + halo, not the full axis. |
| 707 | + if ( |
| 708 | + len(indexes_coords) == 1 |
| 709 | + and method in ("linear", "nearest", "slinear") |
| 710 | + and is_chunked_array(var._data) |
| 711 | + ): |
| 712 | + dim = next(iter(indexes_coords)) |
| 713 | + in_coord, new_coord = indexes_coords[dim] |
| 714 | + if ( |
| 715 | + getattr(in_coord, "ndim", 1) == 1 |
| 716 | + and getattr(new_coord, "ndim", 1) == 1 |
| 717 | + and dim in var.dims |
| 718 | + ): |
| 719 | + fast = _interp1d_dask_chunked(var, dim, in_coord, new_coord, func, kwargs) |
| 720 | + if fast is not None: |
| 721 | + return fast |
| 722 | + |
702 | 723 | in_coords, result_coords = zip(*(v for v in indexes_coords.values()), strict=True) |
703 | 724 |
|
704 | 725 | # input coordinates along which we are interpolation are core dimensions |
@@ -765,6 +786,143 @@ def interpolate_variable( |
765 | 786 | return result |
766 | 787 |
|
767 | 788 |
|
| 789 | +def _interp1d_dask_chunked( |
| 790 | + var: Variable, |
| 791 | + dim: Hashable, |
| 792 | + in_coord: Variable, |
| 793 | + new_coord: Variable, |
| 794 | + func, |
| 795 | + kwargs: dict[str, Any], |
| 796 | +) -> Variable | None: |
| 797 | + """Apply separable 1D interp to a dask-chunked Variable without |
| 798 | + rechunking the core dim. |
| 799 | +
|
| 800 | + Routes each target point to the source chunk that contains its coord |
| 801 | + value, slices that chunk plus a size-1 halo, and runs the interpolator |
| 802 | + per-chunk. Output chunks along ``dim`` follow the distribution of |
| 803 | + target points across source chunks; leading/trailing dims keep the |
| 804 | + input chunking. |
| 805 | +
|
| 806 | + Returns ``None`` to signal a fall-back (caller should use the existing |
| 807 | + apply_ufunc path). Fall-back cases: empty target/source, non-monotonic |
| 808 | + source coord, or source with a single chunk along ``dim``. |
| 809 | + """ |
| 810 | + import dask.array as da |
| 811 | + |
| 812 | + # Caller guarantees var._data is chunked (is_chunked_array check). |
| 813 | + src = cast(da.Array, var._data) |
| 814 | + axis = var.dims.index(dim) |
| 815 | + |
| 816 | + # Materialize the 1D coords. If they're lazy, this forces a compute — |
| 817 | + # which is cheap for 1D coord arrays but trips strict |
| 818 | + # ``raise_if_dask_computes`` assertions. The alternative (building a |
| 819 | + # fully-lazy per-chunk graph without knowing coord values) would |
| 820 | + # require routing logic inside the compute, which defeats the point. |
| 821 | + # Small cost here buys a vastly better task graph. |
| 822 | + in_np = np.asarray(in_coord) |
| 823 | + new_np = np.asarray(new_coord) |
| 824 | + |
| 825 | + if in_np.size == 0 or new_np.size == 0: |
| 826 | + return None |
| 827 | + # Datetime / timedelta / object coords: the apply_ufunc path converts |
| 828 | + # these to float64 via ``_floatize_x`` before handing to scipy. Fall |
| 829 | + # back rather than duplicating that plumbing here. |
| 830 | + if in_np.dtype.kind not in "fiu" or new_np.dtype.kind not in "fiu": |
| 831 | + return None |
| 832 | + if in_np.size > 1 and not ( |
| 833 | + bool(np.all(in_np[1:] > in_np[:-1])) or bool(np.all(in_np[1:] < in_np[:-1])) |
| 834 | + ): |
| 835 | + return None # unsorted source coord — fall back |
| 836 | + |
| 837 | + # Work with ascending source coord. Reversing both ``in_np`` and ``src`` |
| 838 | + # along the core dim produces the same interp result as reversing the |
| 839 | + # order of searchsorted buckets — so no further compensation is needed |
| 840 | + # at the end. |
| 841 | + if in_np[0] > in_np[-1]: |
| 842 | + in_np = in_np[::-1] |
| 843 | + src = da.flip(src, axis=axis) |
| 844 | + |
| 845 | + chunks_along = src.chunks[axis] |
| 846 | + if len(chunks_along) == 1: |
| 847 | + return None # already one chunk — existing fast path handles it |
| 848 | + |
| 849 | + boundaries = np.concatenate(([0], np.cumsum(chunks_along))) |
| 850 | + |
| 851 | + # Route each target point to a source chunk: use searchsorted on the |
| 852 | + # values at the chunk-end positions. A target point at x gets assigned |
| 853 | + # to the first chunk whose end is >= x. |
| 854 | + chunk_end_vals = in_np[boundaries[1:] - 1] |
| 855 | + chunk_of_target = np.searchsorted(chunk_end_vals, new_np, side="left") |
| 856 | + chunk_of_target = np.clip(chunk_of_target, 0, len(chunks_along) - 1) |
| 857 | + |
| 858 | + # Build one block per source chunk; concat in target order. |
| 859 | + blocks: list[tuple[np.ndarray, da.Array]] = [] |
| 860 | + |
| 861 | + for ci in range(len(chunks_along)): |
| 862 | + mask = chunk_of_target == ci |
| 863 | + if not mask.any(): |
| 864 | + continue |
| 865 | + tgt_idx = np.where(mask)[0] |
| 866 | + tgt_vals = new_np[tgt_idx] |
| 867 | + |
| 868 | + halo_start = max(0, int(boundaries[ci]) - 1) |
| 869 | + halo_end = min(int(src.shape[axis]), int(boundaries[ci + 1]) + 1) |
| 870 | + |
| 871 | + slicer = tuple( |
| 872 | + slice(halo_start, halo_end) if i == axis else slice(None) |
| 873 | + for i in range(src.ndim) |
| 874 | + ) |
| 875 | + # Halo ranges straddle the chunk boundary by construction, so |
| 876 | + # rechunk this tiny slice to a single block along the interp axis. |
| 877 | + # This is the key "map_overlap"-like step — only the local halo |
| 878 | + # gets materialized per task, not the full axis. |
| 879 | + sub_src = cast(da.Array, src[slicer]).rechunk({axis: -1}) |
| 880 | + sub_coord = in_np[halo_start:halo_end] |
| 881 | + |
| 882 | + # Per-chunk kernel: scipy 1D interp applied along `axis`. |
| 883 | + def _kernel(block, sub_coord=sub_coord, tgt_vals=tgt_vals, axis=axis): |
| 884 | + return func(sub_coord, block, **kwargs)(tgt_vals) |
| 885 | + |
| 886 | + out_chunks = tuple( |
| 887 | + (len(tgt_vals),) if i == axis else c for i, c in enumerate(sub_src.chunks) |
| 888 | + ) |
| 889 | + sub_out = sub_src.map_blocks(_kernel, dtype=float, chunks=out_chunks) |
| 890 | + blocks.append((tgt_idx, sub_out)) |
| 891 | + |
| 892 | + if not blocks: |
| 893 | + # No target points land in any chunk — shouldn't happen given |
| 894 | + # the clip above, but fall back just in case. |
| 895 | + return None |
| 896 | + |
| 897 | + # Concatenate in ascending chunk order, then gather back into target order. |
| 898 | + # If target coord is monotonic (ascending), this is already in the right |
| 899 | + # order within each chunk and tgt_idx values concatenate to np.arange. |
| 900 | + order = np.concatenate([tgt for tgt, _ in blocks]) |
| 901 | + combined = da.concatenate([arr for _, arr in blocks], axis=axis) |
| 902 | + |
| 903 | + if not np.array_equal(order, np.arange(len(new_np))): |
| 904 | + # Need to permute along axis to restore target order. |
| 905 | + inv = np.argsort(order) |
| 906 | + # da doesn't support int-array fancy indexing along a single axis |
| 907 | + # cleanly for arbitrary-D; use take which does. |
| 908 | + combined = da.take(combined, inv, axis=axis) |
| 909 | + |
| 910 | + # Coalesce the target-axis chunks. Per-source-chunk emission creates |
| 911 | + # many tiny pieces (one per source chunk with any target point); |
| 912 | + # re-chunking to approximately the source axis's max chunk keeps the |
| 913 | + # output graph size reasonable without materializing anything. |
| 914 | + out_chunk_target = max(chunks_along) |
| 915 | + if any(c < out_chunk_target for c in combined.chunks[axis]): |
| 916 | + new_chunks = {axis: out_chunk_target} |
| 917 | + combined = combined.rechunk(new_chunks) |
| 918 | + |
| 919 | + # The target order in `combined` already reflects ``new_np`` in its input |
| 920 | + # order — any flip of the source coord was absorbed when we reversed |
| 921 | + # ``in_np`` and ``src`` at the top. |
| 922 | + |
| 923 | + return Variable(var.dims, combined, attrs=var.attrs, fastpath=True) |
| 924 | + |
| 925 | + |
768 | 926 | def _interp1d( |
769 | 927 | var: Variable, |
770 | 928 | x_: list[Variable], |
|
0 commit comments