|
10 | 10 | import operator |
11 | 11 | import os |
12 | 12 | from collections.abc import Generator, Hashable, Iterable, Sequence |
13 | | -from functools import reduce, wraps |
| 13 | +from functools import partial, reduce, wraps |
14 | 14 | from pathlib import Path |
15 | 15 | from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, overload |
16 | 16 | from warnings import warn |
|
19 | 19 | import pandas as pd |
20 | 20 | import polars as pl |
21 | 21 | from numpy import arange, signedinteger |
22 | | -from xarray import DataArray, Dataset, align, apply_ufunc, broadcast |
23 | | -from xarray.core import indexing |
| 22 | +from pandas.util._decorators import doc |
| 23 | +from xarray import DataArray, Dataset, apply_ufunc, broadcast |
| 24 | +from xarray import align as xr_align |
| 25 | +from xarray.core import dtypes, indexing |
| 26 | +from xarray.core.types import JoinOptions, T_Alignable |
24 | 27 | from xarray.namedarray.utils import is_dict_like |
25 | 28 |
|
26 | 29 | from linopy.config import options |
@@ -426,13 +429,13 @@ def save_join(*dataarrays: DataArray, integer_dtype: bool = False) -> Dataset: |
426 | 429 | Join multiple xarray Dataarray's to a Dataset and warn if coordinates are not equal. |
427 | 430 | """ |
428 | 431 | try: |
429 | | - arrs = align(*dataarrays, join="exact") |
| 432 | + arrs = xr_align(*dataarrays, join="exact") |
430 | 433 | except ValueError: |
431 | 434 | warn( |
432 | 435 | "Coordinates across variables not equal. Perform outer join.", |
433 | 436 | UserWarning, |
434 | 437 | ) |
435 | | - arrs = align(*dataarrays, join="outer") |
| 438 | + arrs = xr_align(*dataarrays, join="outer") |
436 | 439 | if integer_dtype: |
437 | 440 | arrs = tuple([ds.fillna(-1).astype(int) for ds in arrs]) |
438 | 441 | return Dataset({ds.name: ds for ds in arrs}) |
@@ -987,6 +990,50 @@ def check_common_keys_values(list_of_dicts: list[dict[str, Any]]) -> bool: |
987 | 990 | return all(len({d[k] for d in list_of_dicts if k in d}) == 1 for k in common_keys) |
988 | 991 |
|
989 | 992 |
|
| 993 | +@doc(xr_align) |
| 994 | +def align( |
| 995 | + *objects: LinearExpression | Variable | T_Alignable, |
| 996 | + join: JoinOptions = "inner", |
| 997 | + copy: bool = True, |
| 998 | + indexes=None, |
| 999 | + exclude: str | Iterable[Hashable] = frozenset(), |
| 1000 | + fill_value=dtypes.NA, |
| 1001 | +) -> tuple[LinearExpression | Variable | T_Alignable, ...]: |
| 1002 | + from linopy.expressions import LinearExpression |
| 1003 | + from linopy.variables import Variable |
| 1004 | + |
| 1005 | + finisher = [] |
| 1006 | + das = [] |
| 1007 | + for obj in objects: |
| 1008 | + if isinstance(obj, LinearExpression): |
| 1009 | + finisher.append(partial(obj.__class__, model=obj.model)) |
| 1010 | + das.append(obj.data) |
| 1011 | + elif isinstance(obj, Variable): |
| 1012 | + finisher.append( |
| 1013 | + partial( |
| 1014 | + obj.__class__, |
| 1015 | + model=obj.model, |
| 1016 | + name=obj.data.attrs["name"], |
| 1017 | + skip_broadcast=True, |
| 1018 | + ) |
| 1019 | + ) |
| 1020 | + das.append(obj.data) |
| 1021 | + else: |
| 1022 | + finisher.append(lambda x: x) |
| 1023 | + das.append(obj) |
| 1024 | + |
| 1025 | + exclude = frozenset(exclude).union(HELPER_DIMS) |
| 1026 | + aligned = xr_align( |
| 1027 | + *das, |
| 1028 | + join=join, |
| 1029 | + copy=copy, |
| 1030 | + indexes=indexes, |
| 1031 | + exclude=exclude, |
| 1032 | + fill_value=fill_value, |
| 1033 | + ) |
| 1034 | + return tuple([f(da) for f, da in zip(finisher, aligned)]) |
| 1035 | + |
| 1036 | + |
990 | 1037 | LocT = TypeVar("LocT", "Dataset", "Variable", "LinearExpression", "Constraint") |
991 | 1038 |
|
992 | 1039 |
|
|
0 commit comments