Skip to content

Commit 639f126

Browse files
coroaFabianHofmann
andauthored
feat: add method wrapper for xr.align (#426)
* feat: add align method wrapping xr.align * add release notes --------- Co-authored-by: Fabian Hofmann <fab.hof@gmx.de>
1 parent 6fdf678 commit 639f126

File tree

4 files changed

+101
-5
lines changed

4 files changed

+101
-5
lines changed

doc/release_notes.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@ Upcoming Version
1111

1212
**Features**
1313

14+
** Features **
15+
1416
* Added support for arithmetic operations with custom classes.
17+
* Added `align` function as a wrapper around :func:`xr.align`.
1518
* Avoid allocating a floating license for COPT during the initial solver check
1619

1720
Version 0.5.0

linopy/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# Note: For intercepting multiplications between xarray dataarrays, Variables and Expressions
1313
# we need to extend their __mul__ functions with a quick special case
1414
import linopy.monkey_patch_xarray # noqa: F401
15+
from linopy.common import align
1516
from linopy.config import options
1617
from linopy.constants import EQUAL, GREATER_EQUAL, LESS_EQUAL
1718
from linopy.constraints import Constraint, Constraints
@@ -35,6 +36,7 @@
3536
"Variable",
3637
"Variables",
3738
"available_solvers",
39+
"align",
3840
"merge",
3941
"options",
4042
"read_netcdf",

linopy/common.py

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import operator
1111
import os
1212
from collections.abc import Generator, Hashable, Iterable, Sequence
13-
from functools import reduce, wraps
13+
from functools import partial, reduce, wraps
1414
from pathlib import Path
1515
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, overload
1616
from warnings import warn
@@ -19,8 +19,11 @@
1919
import pandas as pd
2020
import polars as pl
2121
from numpy import arange, signedinteger
22-
from xarray import DataArray, Dataset, align, apply_ufunc, broadcast
23-
from xarray.core import indexing
22+
from pandas.util._decorators import doc
23+
from xarray import DataArray, Dataset, apply_ufunc, broadcast
24+
from xarray import align as xr_align
25+
from xarray.core import dtypes, indexing
26+
from xarray.core.types import JoinOptions, T_Alignable
2427
from xarray.namedarray.utils import is_dict_like
2528

2629
from linopy.config import options
@@ -426,13 +429,13 @@ def save_join(*dataarrays: DataArray, integer_dtype: bool = False) -> Dataset:
426429
Join multiple xarray Dataarray's to a Dataset and warn if coordinates are not equal.
427430
"""
428431
try:
429-
arrs = align(*dataarrays, join="exact")
432+
arrs = xr_align(*dataarrays, join="exact")
430433
except ValueError:
431434
warn(
432435
"Coordinates across variables not equal. Perform outer join.",
433436
UserWarning,
434437
)
435-
arrs = align(*dataarrays, join="outer")
438+
arrs = xr_align(*dataarrays, join="outer")
436439
if integer_dtype:
437440
arrs = tuple([ds.fillna(-1).astype(int) for ds in arrs])
438441
return Dataset({ds.name: ds for ds in arrs})
@@ -987,6 +990,50 @@ def check_common_keys_values(list_of_dicts: list[dict[str, Any]]) -> bool:
987990
return all(len({d[k] for d in list_of_dicts if k in d}) == 1 for k in common_keys)
988991

989992

993+
@doc(xr_align)
994+
def align(
995+
*objects: LinearExpression | Variable | T_Alignable,
996+
join: JoinOptions = "inner",
997+
copy: bool = True,
998+
indexes=None,
999+
exclude: str | Iterable[Hashable] = frozenset(),
1000+
fill_value=dtypes.NA,
1001+
) -> tuple[LinearExpression | Variable | T_Alignable, ...]:
1002+
from linopy.expressions import LinearExpression
1003+
from linopy.variables import Variable
1004+
1005+
finisher = []
1006+
das = []
1007+
for obj in objects:
1008+
if isinstance(obj, LinearExpression):
1009+
finisher.append(partial(obj.__class__, model=obj.model))
1010+
das.append(obj.data)
1011+
elif isinstance(obj, Variable):
1012+
finisher.append(
1013+
partial(
1014+
obj.__class__,
1015+
model=obj.model,
1016+
name=obj.data.attrs["name"],
1017+
skip_broadcast=True,
1018+
)
1019+
)
1020+
das.append(obj.data)
1021+
else:
1022+
finisher.append(lambda x: x)
1023+
das.append(obj)
1024+
1025+
exclude = frozenset(exclude).union(HELPER_DIMS)
1026+
aligned = xr_align(
1027+
*das,
1028+
join=join,
1029+
copy=copy,
1030+
indexes=indexes,
1031+
exclude=exclude,
1032+
fill_value=fill_value,
1033+
)
1034+
return tuple([f(da) for f, da in zip(finisher, aligned)])
1035+
1036+
9901037
LocT = TypeVar("LocT", "Dataset", "Variable", "LinearExpression", "Constraint")
9911038

9921039

test/test_common.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,20 @@
99
import pandas as pd
1010
import pytest
1111
import xarray as xr
12+
from test_linear_expression import m, u, x # noqa: F401
1213
from xarray import DataArray
14+
from xarray.testing.assertions import assert_equal
1315

16+
from linopy import Model, Variable
1417
from linopy.common import (
18+
align,
1519
as_dataarray,
1620
assign_multiindex_safe,
1721
best_int,
1822
get_dims_with_index_levels,
1923
iterate_slices,
2024
)
25+
from linopy.testing import assert_linequal, assert_varequal
2126

2227

2328
def test_as_dataarray_with_series_dims_default() -> None:
@@ -644,3 +649,42 @@ def test_get_dims_with_index_levels() -> None:
644649
# Test case 5: Empty dataset
645650
ds5 = xr.Dataset()
646651
assert get_dims_with_index_levels(ds5) == []
652+
653+
654+
def test_align(m: Model, x: Variable, u: Variable) -> None:
655+
alpha = xr.DataArray([1, 2], [[1, 2]])
656+
beta = xr.DataArray(
657+
[1, 2, 3],
658+
[
659+
(
660+
"dim_3",
661+
pd.MultiIndex.from_tuples(
662+
[(1, "b"), (2, "b"), (1, "c")], names=["level1", "level2"]
663+
),
664+
)
665+
],
666+
)
667+
668+
# inner join
669+
x_obs, alpha_obs = align(x, alpha)
670+
assert x_obs.shape == alpha_obs.shape == (1,)
671+
assert_varequal(x_obs, x.loc[[1]])
672+
673+
# left-join
674+
x_obs, alpha_obs = align(x, alpha, join="left")
675+
assert x_obs.shape == alpha_obs.shape == (2,)
676+
assert_varequal(x_obs, x)
677+
assert_equal(alpha_obs, DataArray([np.nan, 1], [[0, 1]]))
678+
679+
# multiindex
680+
beta_obs, u_obs = align(beta, u)
681+
assert u_obs.shape == beta_obs.shape == (2,)
682+
assert_varequal(u_obs, u.loc[[(1, "b"), (2, "b")]])
683+
assert_equal(beta_obs, beta.loc[[(1, "b"), (2, "b")]])
684+
685+
# with linear expression
686+
expr = 20 * x
687+
x_obs, expr_obs, alpha_obs = align(x, expr, alpha)
688+
assert x_obs.shape == alpha_obs.shape == (1,)
689+
assert expr_obs.shape == (1, 1) # _term dim
690+
assert_linequal(expr_obs, expr.loc[[1]])

0 commit comments

Comments
 (0)