From 5c538674b68495d69c5f07c575a789a882f297d6 Mon Sep 17 00:00:00 2001 From: Jonas Hoersch Date: Wed, 12 Mar 2025 11:24:38 +0100 Subject: [PATCH 1/2] feat: add align method wrapping xr.align --- linopy/__init__.py | 2 ++ linopy/common.py | 57 +++++++++++++++++++++++++++++++++++++++++---- test/test_common.py | 44 ++++++++++++++++++++++++++++++++++ 3 files changed, 98 insertions(+), 5 deletions(-) diff --git a/linopy/__init__.py b/linopy/__init__.py index a789c027..88ee5251 100644 --- a/linopy/__init__.py +++ b/linopy/__init__.py @@ -12,6 +12,7 @@ # Note: For intercepting multiplications between xarray dataarrays, Variables and Expressions # we need to extend their __mul__ functions with a quick special case import linopy.monkey_patch_xarray # noqa: F401 +from linopy.common import align from linopy.config import options from linopy.constants import EQUAL, GREATER_EQUAL, LESS_EQUAL from linopy.constraints import Constraint, Constraints @@ -35,6 +36,7 @@ "Variable", "Variables", "available_solvers", + "align", "merge", "options", "read_netcdf", diff --git a/linopy/common.py b/linopy/common.py index 0dd66d39..668b7731 100644 --- a/linopy/common.py +++ b/linopy/common.py @@ -10,7 +10,7 @@ import operator import os from collections.abc import Generator, Hashable, Iterable, Sequence -from functools import reduce, wraps +from functools import partial, reduce, wraps from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, overload from warnings import warn @@ -19,8 +19,11 @@ import pandas as pd import polars as pl from numpy import arange, signedinteger -from xarray import DataArray, Dataset, align, apply_ufunc, broadcast -from xarray.core import indexing +from pandas.util._decorators import doc +from xarray import DataArray, Dataset, apply_ufunc, broadcast +from xarray import align as xr_align +from xarray.core import dtypes, indexing +from xarray.core.types import JoinOptions, T_Alignable from xarray.namedarray.utils import is_dict_like from linopy.config import options @@ -426,13 +429,13 @@ def save_join(*dataarrays: DataArray, integer_dtype: bool = False) -> Dataset: Join multiple xarray Dataarray's to a Dataset and warn if coordinates are not equal. """ try: - arrs = align(*dataarrays, join="exact") + arrs = xr_align(*dataarrays, join="exact") except ValueError: warn( "Coordinates across variables not equal. Perform outer join.", UserWarning, ) - arrs = align(*dataarrays, join="outer") + arrs = xr_align(*dataarrays, join="outer") if integer_dtype: arrs = tuple([ds.fillna(-1).astype(int) for ds in arrs]) return Dataset({ds.name: ds for ds in arrs}) @@ -987,6 +990,50 @@ def check_common_keys_values(list_of_dicts: list[dict[str, Any]]) -> bool: return all(len({d[k] for d in list_of_dicts if k in d}) == 1 for k in common_keys) +@doc(xr_align) +def align( + *objects: LinearExpression | Variable | T_Alignable, + join: JoinOptions = "inner", + copy: bool = True, + indexes=None, + exclude: str | Iterable[Hashable] = frozenset(), + fill_value=dtypes.NA, +) -> tuple[LinearExpression | Variable | T_Alignable, ...]: + from linopy.expressions import LinearExpression + from linopy.variables import Variable + + finisher = [] + das = [] + for obj in objects: + if isinstance(obj, LinearExpression): + finisher.append(partial(obj.__class__, model=obj.model)) + das.append(obj.data) + elif isinstance(obj, Variable): + finisher.append( + partial( + obj.__class__, + model=obj.model, + name=obj.data.attrs["name"], + skip_broadcast=True, + ) + ) + das.append(obj.data) + else: + finisher.append(lambda x: x) + das.append(obj) + + exclude = frozenset(exclude).union(HELPER_DIMS) + aligned = xr_align( + *das, + join=join, + copy=copy, + indexes=indexes, + exclude=exclude, + fill_value=fill_value, + ) + return tuple([f(da) for f, da in zip(finisher, aligned)]) + + LocT = TypeVar("LocT", "Dataset", "Variable", "LinearExpression", "Constraint") diff --git a/test/test_common.py b/test/test_common.py index 565a6d2f..8f0da7a3 100644 --- a/test/test_common.py +++ b/test/test_common.py @@ -9,15 +9,20 @@ import pandas as pd import pytest import xarray as xr +from test_linear_expression import m, u, x # noqa: F401 from xarray import DataArray +from xarray.testing.assertions import assert_equal +from linopy import Model, Variable from linopy.common import ( + align, as_dataarray, assign_multiindex_safe, best_int, get_dims_with_index_levels, iterate_slices, ) +from linopy.testing import assert_linequal, assert_varequal def test_as_dataarray_with_series_dims_default() -> None: @@ -644,3 +649,42 @@ def test_get_dims_with_index_levels() -> None: # Test case 5: Empty dataset ds5 = xr.Dataset() assert get_dims_with_index_levels(ds5) == [] + + +def test_align(m: Model, x: Variable, u: Variable) -> None: + alpha = xr.DataArray([1, 2], [[1, 2]]) + beta = xr.DataArray( + [1, 2, 3], + [ + ( + "dim_3", + pd.MultiIndex.from_tuples( + [(1, "b"), (2, "b"), (1, "c")], names=["level1", "level2"] + ), + ) + ], + ) + + # inner join + x_obs, alpha_obs = align(x, alpha) + assert x_obs.shape == alpha_obs.shape == (1,) + assert_varequal(x_obs, x.loc[[1]]) + + # left-join + x_obs, alpha_obs = align(x, alpha, join="left") + assert x_obs.shape == alpha_obs.shape == (2,) + assert_varequal(x_obs, x) + assert_equal(alpha_obs, DataArray([np.nan, 1], [[0, 1]])) + + # multiindex + beta_obs, u_obs = align(beta, u) + assert u_obs.shape == beta_obs.shape == (2,) + assert_varequal(u_obs, u.loc[[(1, "b"), (2, "b")]]) + assert_equal(beta_obs, beta.loc[[(1, "b"), (2, "b")]]) + + # with linear expression + expr = 20 * x + x_obs, expr_obs, alpha_obs = align(x, expr, alpha) + assert x_obs.shape == alpha_obs.shape == (1,) + assert expr_obs.shape == (1, 1) # _term dim + assert_linequal(expr_obs, expr.loc[[1]]) From c4501fb6ff3675cfb93c230cb9fcd6a337be23d4 Mon Sep 17 00:00:00 2001 From: Jonas Hoersch Date: Wed, 12 Mar 2025 12:27:14 +0100 Subject: [PATCH 2/2] add release notes --- doc/release_notes.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/doc/release_notes.rst b/doc/release_notes.rst index ac86ad97..039aac46 100644 --- a/doc/release_notes.rst +++ b/doc/release_notes.rst @@ -4,7 +4,10 @@ Release Notes Upcoming Version ---------------- +** Features ** + * Added support for arithmetic operations with custom classes. +* Added `align` function as a wrapper around :func:`xr.align`. Version 0.5.0 --------------