Skip to content

Commit 4c92d52

Browse files
headtr1ckdcherianpre-commit-ci[bot]
authored
CFTime support for polyval (#6624)
Co-authored-by: dcherian <deepak@cherian.net> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com>
1 parent 95a47af commit 4c92d52

3 files changed

Lines changed: 116 additions & 14 deletions

File tree

xarray/core/computation.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from .options import OPTIONS, _get_keep_attrs
3232
from .pycompat import is_duck_dask_array
3333
from .types import T_DataArray
34-
from .utils import is_dict_like
34+
from .utils import is_dict_like, is_scalar
3535
from .variable import Variable
3636

3737
if TYPE_CHECKING:
@@ -1887,6 +1887,15 @@ def polyval(coord: Dataset, coeffs: Dataset, degree_dim: Hashable) -> Dataset:
18871887
...
18881888

18891889

1890+
@overload
1891+
def polyval(
1892+
coord: Dataset | DataArray,
1893+
coeffs: Dataset | DataArray,
1894+
degree_dim: Hashable = "degree",
1895+
) -> Dataset | DataArray:
1896+
...
1897+
1898+
18901899
def polyval(
18911900
coord: Dataset | DataArray,
18921901
coeffs: Dataset | DataArray,
@@ -1953,15 +1962,21 @@ def _ensure_numeric(data: Dataset | DataArray) -> Dataset | DataArray:
19531962
"""
19541963
from .dataset import Dataset
19551964

1965+
def _cfoffset(x: DataArray) -> Any:
1966+
scalar = x.compute().data[0]
1967+
if not is_scalar(scalar):
1968+
# we do not get a scalar back on dask == 2021.04.1
1969+
scalar = scalar.item()
1970+
return type(scalar)(1970, 1, 1)
1971+
19561972
def to_floatable(x: DataArray) -> DataArray:
1957-
if x.dtype.kind == "M":
1958-
# datetimes
1973+
if x.dtype.kind in "MO":
1974+
# datetimes (CFIndexes are object type)
1975+
offset = (
1976+
np.datetime64("1970-01-01") if x.dtype.kind == "M" else _cfoffset(x)
1977+
)
19591978
return x.copy(
1960-
data=datetime_to_numeric(
1961-
x.data,
1962-
offset=np.datetime64("1970-01-01"),
1963-
datetime_unit="ns",
1964-
),
1979+
data=datetime_to_numeric(x.data, offset=offset, datetime_unit="ns"),
19651980
)
19661981
elif x.dtype.kind == "m":
19671982
# timedeltas

xarray/core/duck_array_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,7 @@ def datetime_to_numeric(array, offset=None, datetime_unit=None, dtype=float):
435435
# This map_blocks call is for backwards compatibility.
436436
# dask == 2021.04.1 does not support subtracting object arrays
437437
# which is required for cftime
438-
if is_duck_dask_array(array) and np.issubdtype(array.dtype, np.object):
438+
if is_duck_dask_array(array) and np.issubdtype(array.dtype, object):
439439
array = array.map_blocks(lambda a, b: a - b, offset, meta=array._meta)
440440
else:
441441
array = array - offset

xarray/tests/test_computation.py

Lines changed: 92 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,13 @@
2626
from xarray.core.pycompat import dask_version
2727
from xarray.core.types import T_Xarray
2828

29-
from . import has_dask, raise_if_dask_computes, requires_dask
29+
from . import (
30+
has_cftime,
31+
has_dask,
32+
raise_if_dask_computes,
33+
requires_cftime,
34+
requires_dask,
35+
)
3036

3137

3238
def assert_identical(a, b):
@@ -1936,7 +1942,9 @@ def test_where_attrs() -> None:
19361942
assert actual.attrs == {}
19371943

19381944

1939-
@pytest.mark.parametrize("use_dask", [False, True])
1945+
@pytest.mark.parametrize(
1946+
"use_dask", [pytest.param(False, id="nodask"), pytest.param(True, id="dask")]
1947+
)
19401948
@pytest.mark.parametrize(
19411949
["x", "coeffs", "expected"],
19421950
[
@@ -2031,20 +2039,99 @@ def test_polyval(
20312039
pytest.skip("requires dask")
20322040
coeffs = coeffs.chunk({"degree": 2})
20332041
x = x.chunk({"x": 2})
2042+
20342043
with raise_if_dask_computes():
2035-
actual = xr.polyval(coord=x, coeffs=coeffs) # type: ignore
2044+
actual = xr.polyval(coord=x, coeffs=coeffs)
2045+
2046+
xr.testing.assert_allclose(actual, expected)
2047+
2048+
2049+
@requires_cftime
2050+
@pytest.mark.parametrize(
2051+
"use_dask", [pytest.param(False, id="nodask"), pytest.param(True, id="dask")]
2052+
)
2053+
@pytest.mark.parametrize("date", ["1970-01-01", "0753-04-21"])
2054+
def test_polyval_cftime(use_dask: bool, date: str) -> None:
2055+
import cftime
2056+
2057+
x = xr.DataArray(
2058+
xr.date_range(date, freq="1S", periods=3, use_cftime=True),
2059+
dims="x",
2060+
)
2061+
coeffs = xr.DataArray([0, 1], dims="degree", coords={"degree": [0, 1]})
2062+
2063+
if use_dask:
2064+
if not has_dask:
2065+
pytest.skip("requires dask")
2066+
coeffs = coeffs.chunk({"degree": 2})
2067+
x = x.chunk({"x": 2})
2068+
2069+
with raise_if_dask_computes(max_computes=1):
2070+
actual = xr.polyval(coord=x, coeffs=coeffs)
2071+
2072+
t0 = xr.date_range(date, periods=1)[0]
2073+
offset = (t0 - cftime.DatetimeGregorian(1970, 1, 1)).total_seconds() * 1e9
2074+
expected = (
2075+
xr.DataArray(
2076+
[0, 1e9, 2e9],
2077+
dims="x",
2078+
coords={"x": xr.date_range(date, freq="1S", periods=3, use_cftime=True)},
2079+
)
2080+
+ offset
2081+
)
20362082
xr.testing.assert_allclose(actual, expected)
20372083

20382084

2039-
def test_polyval_degree_dim_checks():
2040-
x = (xr.DataArray([1, 2, 3], dims="x"),)
2085+
def test_polyval_degree_dim_checks() -> None:
2086+
x = xr.DataArray([1, 2, 3], dims="x")
20412087
coeffs = xr.DataArray([2, 3, 4], dims="degree", coords={"degree": [0, 1, 2]})
20422088
with pytest.raises(ValueError):
20432089
xr.polyval(x, coeffs.drop_vars("degree"))
20442090
with pytest.raises(ValueError):
20452091
xr.polyval(x, coeffs.assign_coords(degree=coeffs.degree.astype(float)))
20462092

20472093

2094+
@pytest.mark.parametrize(
2095+
"use_dask", [pytest.param(False, id="nodask"), pytest.param(True, id="dask")]
2096+
)
2097+
@pytest.mark.parametrize(
2098+
"x",
2099+
[
2100+
pytest.param(xr.DataArray([0, 1, 2], dims="x"), id="simple"),
2101+
pytest.param(
2102+
xr.DataArray(pd.date_range("1970-01-01", freq="ns", periods=3), dims="x"),
2103+
id="datetime",
2104+
),
2105+
pytest.param(
2106+
xr.DataArray(np.array([0, 1, 2], dtype="timedelta64[ns]"), dims="x"),
2107+
id="timedelta",
2108+
),
2109+
],
2110+
)
2111+
@pytest.mark.parametrize(
2112+
"y",
2113+
[
2114+
pytest.param(xr.DataArray([1, 6, 17], dims="x"), id="1D"),
2115+
pytest.param(
2116+
xr.DataArray([[1, 6, 17], [34, 57, 86]], dims=("y", "x")), id="2D"
2117+
),
2118+
],
2119+
)
2120+
def test_polyfit_polyval_integration(
2121+
use_dask: bool, x: xr.DataArray, y: xr.DataArray
2122+
) -> None:
2123+
y.coords["x"] = x
2124+
if use_dask:
2125+
if not has_dask:
2126+
pytest.skip("requires dask")
2127+
y = y.chunk({"x": 2})
2128+
2129+
fit = y.polyfit(dim="x", deg=2)
2130+
evaluated = xr.polyval(y.x, fit.polyfit_coefficients)
2131+
expected = y.transpose(*evaluated.dims)
2132+
xr.testing.assert_allclose(evaluated.variable, expected.variable)
2133+
2134+
20482135
@pytest.mark.parametrize("use_dask", [False, True])
20492136
@pytest.mark.parametrize(
20502137
"a, b, ae, be, dim, axis",

0 commit comments

Comments
 (0)