Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/release_notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ Upcoming Version

**Features**

** Features **

* Added support for arithmetic operations with custom classes.
* Added `align` function as a wrapper around :func:`xr.align`.
* Avoid allocating a floating license for COPT during the initial solver check

Version 0.5.0
Expand Down
2 changes: 2 additions & 0 deletions linopy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# Note: For intercepting multiplications between xarray dataarrays, Variables and Expressions
# we need to extend their __mul__ functions with a quick special case
import linopy.monkey_patch_xarray # noqa: F401
from linopy.common import align
from linopy.config import options
from linopy.constants import EQUAL, GREATER_EQUAL, LESS_EQUAL
from linopy.constraints import Constraint, Constraints
Expand All @@ -35,6 +36,7 @@
"Variable",
"Variables",
"available_solvers",
"align",
"merge",
"options",
"read_netcdf",
Expand Down
57 changes: 52 additions & 5 deletions linopy/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import operator
import os
from collections.abc import Generator, Hashable, Iterable, Sequence
from functools import reduce, wraps
from functools import partial, reduce, wraps
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, overload
from warnings import warn
Expand All @@ -19,8 +19,11 @@
import pandas as pd
import polars as pl
from numpy import arange, signedinteger
from xarray import DataArray, Dataset, align, apply_ufunc, broadcast
from xarray.core import indexing
from pandas.util._decorators import doc
from xarray import DataArray, Dataset, apply_ufunc, broadcast
from xarray import align as xr_align
from xarray.core import dtypes, indexing
from xarray.core.types import JoinOptions, T_Alignable
from xarray.namedarray.utils import is_dict_like

from linopy.config import options
Expand Down Expand Up @@ -426,13 +429,13 @@ def save_join(*dataarrays: DataArray, integer_dtype: bool = False) -> Dataset:
Join multiple xarray Dataarray's to a Dataset and warn if coordinates are not equal.
"""
try:
arrs = align(*dataarrays, join="exact")
arrs = xr_align(*dataarrays, join="exact")
except ValueError:
warn(
"Coordinates across variables not equal. Perform outer join.",
UserWarning,
)
arrs = align(*dataarrays, join="outer")
arrs = xr_align(*dataarrays, join="outer")
if integer_dtype:
arrs = tuple([ds.fillna(-1).astype(int) for ds in arrs])
return Dataset({ds.name: ds for ds in arrs})
Expand Down Expand Up @@ -987,6 +990,50 @@ def check_common_keys_values(list_of_dicts: list[dict[str, Any]]) -> bool:
return all(len({d[k] for d in list_of_dicts if k in d}) == 1 for k in common_keys)


@doc(xr_align)
def align(
*objects: LinearExpression | Variable | T_Alignable,
join: JoinOptions = "inner",
copy: bool = True,
indexes=None,
exclude: str | Iterable[Hashable] = frozenset(),
fill_value=dtypes.NA,
) -> tuple[LinearExpression | Variable | T_Alignable, ...]:
from linopy.expressions import LinearExpression
from linopy.variables import Variable

finisher = []
das = []
for obj in objects:
if isinstance(obj, LinearExpression):
finisher.append(partial(obj.__class__, model=obj.model))
das.append(obj.data)
elif isinstance(obj, Variable):
finisher.append(
partial(
obj.__class__,
model=obj.model,
name=obj.data.attrs["name"],
skip_broadcast=True,
)
)
das.append(obj.data)
else:
finisher.append(lambda x: x)
das.append(obj)

exclude = frozenset(exclude).union(HELPER_DIMS)
aligned = xr_align(
*das,
join=join,
copy=copy,
indexes=indexes,
exclude=exclude,
fill_value=fill_value,
)
return tuple([f(da) for f, da in zip(finisher, aligned)])


LocT = TypeVar("LocT", "Dataset", "Variable", "LinearExpression", "Constraint")


Expand Down
44 changes: 44 additions & 0 deletions test/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,20 @@
import pandas as pd
import pytest
import xarray as xr
from test_linear_expression import m, u, x # noqa: F401
from xarray import DataArray
from xarray.testing.assertions import assert_equal

from linopy import Model, Variable
from linopy.common import (
align,
as_dataarray,
assign_multiindex_safe,
best_int,
get_dims_with_index_levels,
iterate_slices,
)
from linopy.testing import assert_linequal, assert_varequal


def test_as_dataarray_with_series_dims_default() -> None:
Expand Down Expand Up @@ -644,3 +649,42 @@ def test_get_dims_with_index_levels() -> None:
# Test case 5: Empty dataset
ds5 = xr.Dataset()
assert get_dims_with_index_levels(ds5) == []


def test_align(m: Model, x: Variable, u: Variable) -> None:
alpha = xr.DataArray([1, 2], [[1, 2]])
beta = xr.DataArray(
[1, 2, 3],
[
(
"dim_3",
pd.MultiIndex.from_tuples(
[(1, "b"), (2, "b"), (1, "c")], names=["level1", "level2"]
),
)
],
)

# inner join
x_obs, alpha_obs = align(x, alpha)
assert x_obs.shape == alpha_obs.shape == (1,)
assert_varequal(x_obs, x.loc[[1]])

# left-join
x_obs, alpha_obs = align(x, alpha, join="left")
assert x_obs.shape == alpha_obs.shape == (2,)
assert_varequal(x_obs, x)
assert_equal(alpha_obs, DataArray([np.nan, 1], [[0, 1]]))

# multiindex
beta_obs, u_obs = align(beta, u)
assert u_obs.shape == beta_obs.shape == (2,)
assert_varequal(u_obs, u.loc[[(1, "b"), (2, "b")]])
assert_equal(beta_obs, beta.loc[[(1, "b"), (2, "b")]])

# with linear expression
expr = 20 * x
x_obs, expr_obs, alpha_obs = align(x, expr, alpha)
assert x_obs.shape == alpha_obs.shape == (1,)
assert expr_obs.shape == (1, 1) # _term dim
assert_linequal(expr_obs, expr.loc[[1]])
Loading