Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions doc/api/groupby.rst
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,18 @@ Grouper Objects
:toctree: ../generated/

groupers.BinGrouper
groupers.UniqueGrouper
groupers.TimeResampler
groupers.SeasonGrouper
groupers.UniqueGrouper


Resampler Objects
-----------------

.. autosummary::
:toctree: ../generated/

groupers.SeasonResampler
groupers.SeasonResampler.compute_chunks

groupers.TimeResampler
groupers.TimeResampler.compute_chunks
4 changes: 3 additions & 1 deletion doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ v2025.08.1 (unreleased)

New Features
~~~~~~~~~~~~
- Support rechunking by :py:class:`~xarray.groupers.SeasonResampler` for seasonal data analysis (:issue:`10425`, :pull:`10519`).
By `Dhruva Kumar Kaushal <https://github.com/dhruvak001>`_.
- Add convenience methods to :py:class:`~xarray.Coordinates` (:pull:`10318`)
By `Justus Magin <https://github.com/keewis>`_.
- Added :py:func:`load_datatree` for loading ``DataTree`` objects into memory
Expand Down Expand Up @@ -157,7 +159,7 @@ Bug fixes
creates extra variables that don't match the provided coordinate names, instead
of silently ignoring them. The error message suggests using the factory method
pattern with :py:meth:`xarray.Coordinates.from_xindex` and
:py:meth:`Dataset.assign_coords` for advanced use cases (:issue:`10499`).
:py:meth:`Dataset.assign_coords` for advanced use cases (:issue:`10499`, :pull:`10503`).
By `Dhruva Kumar Kaushal <https://github.com/dhruvak001>`_.

Documentation
Expand Down
50 changes: 20 additions & 30 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2486,13 +2486,16 @@ def chunk(
sizes along that dimension will not be updated; non-dask arrays will be
converted into dask arrays with a single block.

Along datetime-like dimensions, a :py:class:`groupers.TimeResampler` object is also accepted.
Along datetime-like dimensions, a :py:class:`Resampler` object
(e.g. :py:class:`groupers.TimeResampler` or :py:class:`groupers.SeasonResampler`)
is also accepted.

Parameters
----------
chunks : int, tuple of int, "auto" or mapping of hashable to int or a TimeResampler, optional
chunks : int, tuple of int, "auto" or mapping of hashable to int or a Resampler, optional
Chunk sizes along each dimension, e.g., ``5``, ``"auto"``, or
``{"x": 5, "y": 5}`` or ``{"x": 5, "time": TimeResampler(freq="YE")}``.
``{"x": 5, "y": 5}`` or ``{"x": 5, "time": TimeResampler(freq="YE")}`` or
``{"time": SeasonResampler(["DJF", "MAM", "JJA", "SON"])}``.
name_prefix : str, default: "xarray-"
Prefix for the name of any new dask arrays.
token : str, optional
Expand Down Expand Up @@ -2527,8 +2530,7 @@ def chunk(
xarray.unify_chunks
dask.array.from_array
"""
from xarray.core.dataarray import DataArray
from xarray.groupers import TimeResampler
from xarray.groupers import Resampler

if chunks is None and not chunks_kwargs:
warnings.warn(
Expand Down Expand Up @@ -2556,41 +2558,29 @@ def chunk(
f"chunks keys {tuple(bad_dims)} not found in data dimensions {tuple(self.sizes.keys())}"
)

def _resolve_frequency(
name: Hashable, resampler: TimeResampler
) -> tuple[int, ...]:
def _resolve_resampler(name: Hashable, resampler: Resampler) -> tuple[int, ...]:
variable = self._variables.get(name, None)
if variable is None:
raise ValueError(
f"Cannot chunk by resampler {resampler!r} for virtual variables."
f"Cannot chunk by resampler {resampler!r} for virtual variable {name!r}."
)
elif not _contains_datetime_like_objects(variable):
if variable.ndim != 1:
raise ValueError(
f"chunks={resampler!r} only supported for datetime variables. "
f"Received variable {name!r} with dtype {variable.dtype!r} instead."
f"chunks={resampler!r} only supported for 1D variables. "
f"Received variable {name!r} with {variable.ndim} dimensions instead."
)

assert variable.ndim == 1
chunks = (
DataArray(
np.ones(variable.shape, dtype=int),
dims=(name,),
coords={name: variable},
newchunks = resampler.compute_chunks(variable, dim=name)
if sum(newchunks) != variable.shape[0]:
raise ValueError(
f"Logic bug in rechunking variable {name!r} using {resampler!r}. "
"New chunks tuple does not match size of data. Please open an issue."
)
.resample({name: resampler})
.sum()
)
# When bins (binning) or time periods are missing (resampling)
# we can end up with NaNs. Drop them.
if chunks.dtype.kind == "f":
chunks = chunks.dropna(name).astype(int)
chunks_tuple: tuple[int, ...] = tuple(chunks.data.tolist())
return chunks_tuple
return newchunks

chunks_mapping_ints: Mapping[Any, T_ChunkDim] = {
name: (
_resolve_frequency(name, chunks)
if isinstance(chunks, TimeResampler)
_resolve_resampler(name, chunks)
if isinstance(chunks, Resampler)
else chunks
)
for name, chunks in chunks_mapping.items()
Expand Down
4 changes: 2 additions & 2 deletions xarray/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from xarray.core.indexes import Index, Indexes
from xarray.core.utils import Frozen
from xarray.core.variable import IndexVariable, Variable
from xarray.groupers import Grouper, TimeResampler
from xarray.groupers import Grouper, Resampler
from xarray.structure.alignment import Aligner

GroupInput: TypeAlias = (
Expand Down Expand Up @@ -201,7 +201,7 @@ def copy(
# FYI in some cases we don't allow `None`, which this doesn't take account of.
# FYI the `str` is for a size string, e.g. "16MB", supported by dask.
T_ChunkDim: TypeAlias = str | int | Literal["auto"] | tuple[int, ...] | None # noqa: PYI051
T_ChunkDimFreq: TypeAlias = Union["TimeResampler", T_ChunkDim]
T_ChunkDimFreq: TypeAlias = Union["Resampler", T_ChunkDim]
T_ChunksFreq: TypeAlias = T_ChunkDim | Mapping[Any, T_ChunkDimFreq]
# We allow the tuple form of this (though arguably we could transition to named dims only)
T_Chunks: TypeAlias = T_ChunkDim | Mapping[Any, T_ChunkDim]
Expand Down
121 changes: 119 additions & 2 deletions xarray/groupers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import operator
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Mapping, Sequence
from collections.abc import Hashable, Mapping, Sequence
from dataclasses import dataclass, field
from itertools import chain, pairwise
from typing import TYPE_CHECKING, Any, Literal, cast
Expand Down Expand Up @@ -52,6 +52,8 @@
"EncodedGroups",
"Grouper",
"Resampler",
"SeasonGrouper",
"SeasonResampler",
Comment thread
dhruvak001 marked this conversation as resolved.
"TimeResampler",
"UniqueGrouper",
]
Expand Down Expand Up @@ -169,7 +171,26 @@ class Resampler(Grouper):
Currently only used for TimeResampler, but could be used for SpaceResampler in the future.
"""

pass
def compute_chunks(self, variable: Variable, *, dim: Hashable) -> tuple[int, ...]:
"""
Compute chunk sizes for this resampler.

This method should be implemented by subclasses to provide appropriate
chunking behavior for their specific resampling strategy.

Parameters
----------
variable : Variable
The variable being chunked.
dim : Hashable
The name of the dimension being chunked.

Returns
-------
tuple[int, ...]
A tuple of chunk sizes for the dimension.
"""
raise NotImplementedError("Subclasses must implement compute_chunks method")


@dataclass
Expand Down Expand Up @@ -565,6 +586,49 @@ def factorize(self, group: T_Group) -> EncodedGroups:
coords=coordinates_from_variable(unique_coord),
)

def compute_chunks(self, variable: Variable, *, dim: Hashable) -> tuple[int, ...]:
"""
Compute chunk sizes for this time resampler.

This method is used during chunking operations to determine appropriate
chunk sizes for the given variable when using this resampler.

Parameters
----------
name : Hashable
The name of the dimension being chunked.
variable : Variable
The variable being chunked.

Returns
-------
tuple[int, ...]
A tuple of chunk sizes for the dimension.
"""
from xarray.core.dataarray import DataArray

if not _contains_datetime_like_objects(variable):
raise ValueError(
f"Computing chunks with {type(self)!r} only supported for datetime variables. "
f"Received variable with dtype {variable.dtype!r} instead."
)

chunks = (
DataArray(
np.ones(variable.shape, dtype=int),
dims=(dim,),
coords={dim: variable},
)
.resample({dim: self})
.sum()
)
# When bins (binning) or time periods are missing (resampling)
# we can end up with NaNs. Drop them.
if chunks.dtype.kind == "f":
chunks = chunks.dropna(dim).astype(int)
chunks_tuple: tuple[int, ...] = tuple(chunks.data.tolist())
return chunks_tuple


def _factorize_given_labels(data: np.ndarray, labels: np.ndarray) -> np.ndarray:
# Copied from flox
Expand Down Expand Up @@ -967,5 +1031,58 @@ def get_label(year, season):

return EncodedGroups(codes=codes, full_index=full_index)

def compute_chunks(self, variable: Variable, *, dim: Hashable) -> tuple[int, ...]:
"""
Compute chunk sizes for this season resampler.

This method is used during chunking operations to determine appropriate
chunk sizes for the given variable when using this resampler.

Parameters
----------
name : Hashable
The name of the dimension being chunked.
variable : Variable
The variable being chunked.

Returns
-------
tuple[int, ...]
A tuple of chunk sizes for the dimension.
"""
from xarray.core.dataarray import DataArray

if not _contains_datetime_like_objects(variable):
raise ValueError(
f"Computing chunks with {type(self)!r} only supported for datetime variables. "
f"Received variable with dtype {variable.dtype!r} instead."
)

if len("".join(self.seasons)) != 12:
raise ValueError(
"Cannot rechunk with a SeasonResampler that does not cover all 12 months. "
f"Received `seasons={self.seasons!r}`."
)

# Create a temporary resampler that ignores drop_incomplete for chunking
# This prevents data from being silently dropped during chunking
resampler_for_chunking = type(self)(seasons=self.seasons, drop_incomplete=False)

chunks = (
DataArray(
np.ones(variable.shape, dtype=int),
dims=(dim,),
coords={dim: variable},
)
.resample({dim: resampler_for_chunking})
.sum()
)
# When bins (binning) or time periods are missing (resampling)
# we can end up with NaNs. Drop them.
if chunks.dtype.kind == "f":
chunks = chunks.dropna(dim).astype(int)
chunks_tuple: tuple[int, ...] = tuple(chunks.data.tolist())
return chunks_tuple

def reset(self) -> Self:
return type(self)(seasons=self.seasons, drop_incomplete=self.drop_incomplete)
Loading
Loading