Skip to content

Commit 430f0f0

Browse files
committed
fix: xarray multiindex warning
1 parent b94aeac commit 430f0f0

3 files changed

Lines changed: 38 additions & 15 deletions

File tree

atlite/aggregate.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,25 +5,26 @@
55
Functions for aggregating results.
66
"""
77

8-
import dask
8+
import scipy.sparse as sp
99
import xarray as xr
10+
from dask.array.core import Array
1011

1112

12-
def aggregate_matrix(da, matrix, index):
13-
if index.name is None:
14-
index = index.rename("dim_0")
15-
if isinstance(da.data, dask.array.core.Array):
13+
def aggregate_matrix(
14+
da: xr.DataArray, matrix: sp.csr_matrix, coords: xr.Coordinates
15+
) -> xr.DataArray:
16+
if isinstance(da.data, Array):
1617
da = da.stack(spatial=("y", "x"))
1718
da = da.chunk(dict(spatial=-1))
1819
return xr.apply_ufunc(
1920
lambda da: da * matrix.T,
2021
da,
2122
input_core_dims=[["spatial"]],
22-
output_core_dims=[[index.name]],
23+
output_core_dims=[list(coords.dims)],
2324
dask="parallelized",
2425
output_dtypes=[da.dtype],
25-
dask_gufunc_kwargs=dict(output_sizes={index.name: index.size}),
26-
).assign_coords(**{index.name: index})
26+
dask_gufunc_kwargs=dict(output_sizes=coords.sizes),
27+
).assign_coords(coords)
2728
else:
2829
da = da.stack(spatial=("y", "x")).transpose("spatial", "time")
29-
return xr.DataArray(matrix * da, [index, da.coords["time"]])
30+
return xr.DataArray(matrix * da, coords.assign(time=da.coords["time"]))

atlite/convert.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
get_windturbineconfig,
4141
windturbine_smooth,
4242
)
43+
from atlite.utils import ensure_coords
4344

4445
logger = logging.getLogger(__name__)
4546

@@ -216,7 +217,7 @@ def convert_and_aggregate(
216217
)
217218

218219
if isinstance(matrix, xr.DataArray):
219-
coords = matrix.indexes.get(matrix.dims[1]).to_frame(index=False)
220+
coords = matrix.indexes[matrix.dims[1]].to_frame(index=False)
220221
if not np.array_equal(coords[["x", "y"]], cutout.grid[["x", "y"]]):
221222
raise ValueError(
222223
"Matrix spatial coordinates not aligned with cutout spatial "
@@ -247,15 +248,17 @@ def convert_and_aggregate(
247248
else:
248249
matrix = csr_matrix(matrix) * spdiag(layout)
249250

250-
# From here on, matrix is defined and ensured to be a csr matrix.
251-
if index is None:
252-
index = pd.RangeIndex(matrix.shape[0])
251+
# guaranteed by code flow above, helps type checker
252+
assert isinstance(matrix, csr_matrix)
253253

254-
results = aggregate_matrix(da, matrix=matrix, index=index)
254+
coords = ensure_coords(index)
255+
if len(coords.dims) > 1:
256+
raise ValueError(f"index must have a single dimension, not: {coords.dims}")
257+
results = aggregate_matrix(da, matrix=matrix, coords=coords)
255258

256259
if per_unit or return_capacity:
257260
caps = matrix.sum(-1)
258-
capacity = xr.DataArray(np.asarray(caps).flatten(), [index])
261+
capacity = xr.DataArray(np.asarray(caps).flatten(), coords)
259262
capacity.attrs["units"] = "MW"
260263

261264
if per_unit:

atlite/utils.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,25 @@
1919
logger = logging.getLogger(__name__)
2020

2121

22+
def ensure_coords(index: pd.Index | xr.Coordinates | None) -> xr.Coordinates:
23+
"""
24+
Convert an index or multiindex to coordinates
25+
"""
26+
if index is None:
27+
coords = xr.Coordinates({"dim_0": pd.RangeIndex(matrix.shape[0])})
28+
elif isinstance(index, pd.MultiIndex):
29+
coords = xr.Coordinates.from_pandas_multiindex(index, index.name or "dim_0")
30+
elif isinstance(index, pd.Index):
31+
coords = xr.Coordinates({index.name or "dim_0": index})
32+
elif isinstance(index, xr.Coordinates):
33+
coords = index
34+
else:
35+
raise ValueError(
36+
f"index must be a pandas index or xarray coordinates, not: {index}"
37+
)
38+
return coords
39+
40+
2241
def migrate_from_cutout_directory(old_cutout_dir, path):
2342
"""
2443
Convert an old style cutout directory to new style netcdf file.

0 commit comments

Comments
 (0)