diff --git a/RELEASE_NOTES.rst b/RELEASE_NOTES.rst index 6ccf8625..c1d987e8 100755 --- a/RELEASE_NOTES.rst +++ b/RELEASE_NOTES.rst @@ -27,6 +27,8 @@ Upcoming Release * Fix regression when ``ExclusionContainer`` encounters a raster with an invalid CRS (https://github.com/PyPSA/atlite/pull/500). +* Fix MultiIndex-specific ``FutureWarning`` in ``convert_and_aggregate`` (https://github.com/PyPSA/atlite/pull/501). + `v0.5.0 <https://github.com/PyPSA/atlite/releases/tag/v0.5.0>`__ (13th March 2026) ======================================================================================= diff --git a/atlite/aggregate.py b/atlite/aggregate.py index e3d3b3a4..44464c2e 100644 --- a/atlite/aggregate.py +++ b/atlite/aggregate.py @@ -5,25 +5,26 @@ Functions for aggregating results. """ -import dask +import scipy.sparse as sp import xarray as xr +from dask.array.core import Array -def aggregate_matrix(da, matrix, index): - if index.name is None: - index = index.rename("dim_0") - if isinstance(da.data, dask.array.core.Array): +def aggregate_matrix( + da: xr.DataArray, matrix: sp.csr_matrix, coords: xr.Coordinates +) -> xr.DataArray: + if isinstance(da.data, Array): da = da.stack(spatial=("y", "x")) da = da.chunk(dict(spatial=-1)) return xr.apply_ufunc( lambda da: da * matrix.T, da, input_core_dims=[["spatial"]], - output_core_dims=[[index.name]], + output_core_dims=[list(coords.dims)], dask="parallelized", output_dtypes=[da.dtype], - dask_gufunc_kwargs=dict(output_sizes={index.name: index.size}), - ).assign_coords(**{index.name: index}) + dask_gufunc_kwargs=dict(output_sizes=coords.sizes), + ).assign_coords(coords) else: da = da.stack(spatial=("y", "x")).transpose("spatial", "time") - return xr.DataArray(matrix * da, [index, da.coords["time"]]) + return xr.DataArray(matrix * da, coords.assign(time=da.coords["time"])) diff --git a/atlite/convert.py b/atlite/convert.py index 250cb785..fab5628e 100644 --- a/atlite/convert.py +++ b/atlite/convert.py @@ -40,6 +40,7 @@ get_windturbineconfig, windturbine_smooth, ) +from atlite.utils import 
ensure_coords logger = logging.getLogger(__name__) @@ -216,7 +217,7 @@ def convert_and_aggregate( ) if isinstance(matrix, xr.DataArray): - coords = matrix.indexes.get(matrix.dims[1]).to_frame(index=False) + coords = matrix.indexes[matrix.dims[1]].to_frame(index=False) if not np.array_equal(coords[["x", "y"]], cutout.grid[["x", "y"]]): raise ValueError( "Matrix spatial coordinates not aligned with cutout spatial " @@ -247,15 +248,17 @@ def convert_and_aggregate( else: matrix = csr_matrix(matrix) * spdiag(layout) - # From here on, matrix is defined and ensured to be a csr matrix. - if index is None: - index = pd.RangeIndex(matrix.shape[0]) + # guaranteed by code flow above, helps type checker + assert isinstance(matrix, csr_matrix) - results = aggregate_matrix(da, matrix=matrix, index=index) + coords = ensure_coords(pd.RangeIndex(matrix.shape[0]) if index is None else index) + if len(coords.dims) > 1: + raise ValueError(f"index must have a single dimension, not: {coords.dims}") + results = aggregate_matrix(da, matrix=matrix, coords=coords) if per_unit or return_capacity: caps = matrix.sum(-1) - capacity = xr.DataArray(np.asarray(caps).flatten(), [index]) + capacity = xr.DataArray(np.asarray(caps).flatten(), coords) capacity.attrs["units"] = "MW" if per_unit: diff --git a/atlite/utils.py b/atlite/utils.py index 86bd5bc3..564a7921 100644 --- a/atlite/utils.py +++ b/atlite/utils.py @@ -19,6 +19,23 @@ logger = logging.getLogger(__name__) +def ensure_coords(index: pd.Index | xr.Coordinates) -> xr.Coordinates: + """ + Convert an index or multiindex to coordinates + """ + if isinstance(index, pd.MultiIndex): + coords = xr.Coordinates.from_pandas_multiindex(index, index.name or "dim_0") + elif isinstance(index, pd.Index): + coords = xr.Coordinates({index.name or "dim_0": index}) + elif isinstance(index, xr.Coordinates): + coords = index + else: + raise ValueError( + f"index must be a pandas index or xarray coordinates, not: {index}" + ) + return coords + + def 
migrate_from_cutout_directory(old_cutout_dir, path): """ Convert an old style cutout directory to new style netcdf file.