From a2041970b1adb754a67696b533c489dfdd75bf74 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Tue, 10 Mar 2026 08:22:48 +0100 Subject: [PATCH 1/9] Add legacy arithmetic join mode with deprecation warning for transition - Add `options["arithmetic_join"]` setting (default: "legacy") to control coordinate alignment in arithmetic operations, merge, and constraints - Legacy mode reproduces old behavior: override when shapes match, outer otherwise for merge; reindex_like for constants; inner for align() - All legacy codepaths emit FutureWarning guiding users to opt in to "exact" - Move shared test fixtures (m, x, y, z, v, u) to conftest.py - Exact-behavior tests use autouse fixture to set arithmetic_join="exact" - Legacy test files (test_*_legacy.py) validate old behavior is preserved - All 2736 tests pass Co-Authored-By: Claude Opus 4.6 --- linopy/common.py | 16 +- linopy/config.py | 33 +- linopy/expressions.py | 94 +- test/conftest.py | 50 + test/test_algebraic_properties.py | 9 + test/test_common.py | 13 +- test/test_common_legacy.py | 735 +++++++++ test/test_constraints.py | 10 + test/test_constraints_legacy.py | 448 ++++++ test/test_linear_expression.py | 44 +- test/test_linear_expression_legacy.py | 2102 +++++++++++++++++++++++++ test/test_typing.py | 9 + test/test_typing_legacy.py | 25 + 13 files changed, 3524 insertions(+), 64 deletions(-) create mode 100644 test/test_common_legacy.py create mode 100644 test/test_constraints_legacy.py create mode 100644 test/test_linear_expression_legacy.py create mode 100644 test/test_typing_legacy.py diff --git a/linopy/common.py b/linopy/common.py index 21f851df..ea1b46d9 100644 --- a/linopy/common.py +++ b/linopy/common.py @@ -1207,7 +1207,7 @@ def deco(cls: Any) -> Any: def align( *objects: LinearExpression | QuadraticExpression | Variable | T_Alignable, - join: JoinOptions = "exact", + join: JoinOptions | None = None, copy: bool = True, indexes: Any = None, exclude: str | Iterable[Hashable] = frozenset(), @@ -1267,9 +1267,23 @@ def align( """ + from linopy.config import options from linopy.expressions import LinearExpression, QuadraticExpression from linopy.variables import Variable + if join is None: + join = options["arithmetic_join"] + + if join == "legacy": + warn( + "The 'legacy' arithmetic join is deprecated and will be removed " + "in a future version. Set linopy.options['arithmetic_join'] = " + "'exact' to opt in to the new behavior.", + FutureWarning, + stacklevel=2, + ) + join = "inner" + # Extract underlying Datasets for index computation. das: list[Any] = [] for obj in objects: diff --git a/linopy/config.py b/linopy/config.py index c098709d..c5637ce2 100644 --- a/linopy/config.py +++ b/linopy/config.py @@ -9,28 +9,43 @@ from typing import Any +VALID_ARITHMETIC_JOINS = { + "exact", + "inner", + "outer", + "left", + "right", + "override", + "legacy", +} + class OptionSettings: - def __init__(self, **kwargs: int) -> None: + def __init__(self, **kwargs: Any) -> None: self._defaults = kwargs self._current_values = kwargs.copy() - def __call__(self, **kwargs: int) -> None: + def __call__(self, **kwargs: Any) -> None: self.set_value(**kwargs) - def __getitem__(self, key: str) -> int: + def __getitem__(self, key: str) -> Any: return self.get_value(key) - def __setitem__(self, key: str, value: int) -> None: + def __setitem__(self, key: str, value: Any) -> None: return self.set_value(**{key: value}) - def set_value(self, **kwargs: int) -> None: + def set_value(self, **kwargs: Any) -> None: for k, v in kwargs.items(): if k not in self._defaults: raise KeyError(f"{k} is not a valid setting.") + if k == "arithmetic_join" and v not in VALID_ARITHMETIC_JOINS: + raise ValueError( + f"Invalid arithmetic_join: {v!r}. " + f"Must be one of {VALID_ARITHMETIC_JOINS}." + ) self._current_values[k] = v - def get_value(self, name: str) -> int: + def get_value(self, name: str) -> Any: if name in self._defaults: return self._current_values[name] else: @@ -57,4 +72,8 @@ def __repr__(self) -> str: return f"OptionSettings:\n {settings}" -options = OptionSettings(display_max_rows=14, display_max_terms=6) +options = OptionSettings( + display_max_rows=14, + display_max_terms=6, + arithmetic_join="legacy", +) diff --git a/linopy/expressions.py b/linopy/expressions.py index 32bf781b..bb08870b 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -527,7 +527,6 @@ def _align_constant( other: DataArray, fill_value: float = 0, join: str | None = None, - default_join: str = "exact", ) -> tuple[DataArray, DataArray, bool]: """ Align a constant DataArray with self.const. @@ -539,10 +538,7 @@ def _align_constant( fill_value : float, default: 0 Fill value for missing coordinates. join : str, optional - Alignment method. If None, uses default_join. - default_join : str, default: "exact" - Default join mode when join is None. Use "exact" for add/sub, - "inner" for mul/div. + Alignment method. If None, uses ``options["arithmetic_join"]``. Returns ------- @@ -554,7 +550,24 @@ def _align_constant( Whether the expression's data needs reindexing. """ if join is None: - join = default_join + join = options["arithmetic_join"] + + if join == "legacy": + warn( + "The 'legacy' arithmetic join is deprecated and will be removed in a " + "future version. Set linopy.options['arithmetic_join'] = 'exact' to " + "opt in to the new behavior.", + FutureWarning, + stacklevel=4, + ) + # Old behavior: override when same sizes, left join otherwise + if other.sizes == self.const.sizes: + return self.const, other.assign_coords(coords=self.coords), False + return ( + self.const, + other.reindex_like(self.const, fill_value=fill_value), + False, + ) if join == "override": return self.const, other.assign_coords(coords=self.coords), False @@ -589,7 +602,7 @@ def _add_constant( return self.assign(const=self.const + other) da = as_dataarray(other, coords=self.coords, dims=self.coord_dims) self_const, da, needs_data_reindex = self._align_constant( - da, fill_value=0, join=join, default_join="exact" + da, fill_value=0, join=join ) if needs_data_reindex: fv = {**self._fill_value, "const": 0} @@ -610,7 +623,7 @@ def _apply_constant_op( ) -> GenericExpression: factor = as_dataarray(other, coords=self.coords, dims=self.coord_dims) self_const, factor, needs_data_reindex = self._align_constant( - factor, fill_value=fill_value, join=join, default_join="exact" + factor, fill_value=fill_value, join=join ) if needs_data_reindex: fv = {**self._fill_value, "const": 0} @@ -1093,8 +1106,35 @@ def to_constraint( f"Both sides of the constraint are constant. At least one side must contain variables. {self} {rhs}" ) + effective_join = join if join is not None else options["arithmetic_join"] + + if effective_join == "legacy": + warn( + "The 'legacy' arithmetic join is deprecated and will be removed " + "in a future version. Set linopy.options['arithmetic_join'] = " + "'exact' to opt in to the new behavior.", + FutureWarning, + stacklevel=3, + ) + # Old behavior: convert to DataArray, warn about extra dims, + # reindex_like (left join), then sub + if isinstance(rhs, SUPPORTED_CONSTANT_TYPES): + rhs = as_dataarray(rhs, coords=self.coords, dims=self.coord_dims) + extra_dims = set(rhs.dims) - set(self.coord_dims) + if extra_dims: + logger.warning( + f"Constant RHS contains dimensions {extra_dims} not present " + f"in the expression, which might lead to inefficiencies. " + f"Consider collapsing the dimensions by taking min/max." + ) + rhs = rhs.reindex_like(self.const, fill_value=np.nan) + all_to_lhs = self.sub(rhs, join=join).data + data = assign_multiindex_safe( + all_to_lhs[["coeffs", "vars"]], sign=sign, rhs=-all_to_lhs.const + ) + return constraints.Constraint(data, model=self.model) + if isinstance(rhs, DataArray): - effective_join = join if join is not None else "exact" if effective_join == "override": aligned_rhs = rhs.assign_coords(coords=self.const.coords) expr_const = self.const @@ -1127,13 +1167,6 @@ def to_constraint( expr_data[["coeffs", "vars"]], sign=sign, rhs=constraint_rhs ) return constraints.Constraint(data, model=self.model) - elif isinstance(rhs, np.ndarray | pd.Series | pd.DataFrame) and rhs.ndim > len( - self.coord_dims - ): - raise ValueError( - f"RHS has {rhs.ndim} dimensions, but the expression only " - f"has {len(self.coord_dims)}. Cannot create constraint." - ) all_to_lhs = self.sub(rhs, join=join).data data = assign_multiindex_safe( @@ -2413,10 +2446,33 @@ def merge( elif cls == variables.Variable: kwargs["fill_value"] = variables.FILL_VALUE - if join is not None: - kwargs["join"] = join + effective_join = join if join is not None else options["arithmetic_join"] + + if effective_join == "legacy": + warn( + "The 'legacy' arithmetic join is deprecated and will be removed " + "in a future version. Set linopy.options['arithmetic_join'] = " + "'exact' to opt in to the new behavior.", + FutureWarning, + stacklevel=2, + ) + # Reproduce old behavior: override when all shared dims have + # matching sizes, outer otherwise. + if cls in linopy_types and dim in HELPER_DIMS: + coord_dims = [ + {k: v for k, v in e.sizes.items() if k not in HELPER_DIMS} + for e in exprs + ] + common_keys = set.intersection(*(set(d.keys()) for d in coord_dims)) + override = all( + len({d[k] for d in coord_dims if k in d}) == 1 for k in common_keys + ) + else: + override = False + + kwargs["join"] = "override" if override else "outer" else: - kwargs["join"] = "exact" + kwargs["join"] = effective_join try: if dim == TERM_DIM: diff --git a/test/conftest.py b/test/conftest.py index 3197689b..8a4343d6 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -2,8 +2,12 @@ import os +import pandas as pd import pytest +import linopy +from linopy import Model, Variable + def pytest_addoption(parser: pytest.Parser) -> None: """Add custom command line options.""" @@ -48,3 +52,49 @@ def pytest_collection_modifyitems( if solver_supports(solver, SolverFeature.GPU_ACCELERATION): item.add_marker(skip_gpu) item.add_marker(pytest.mark.gpu) + + +@pytest.fixture +def exact_join(): + """Set arithmetic_join to 'exact' for the duration of a test.""" + linopy.options["arithmetic_join"] = "exact" + yield + linopy.options["arithmetic_join"] = "legacy" + + +@pytest.fixture +def m() -> Model: + m = Model() + m.add_variables(pd.Series([0, 0]), 1, name="x") + m.add_variables(4, pd.Series([8, 10]), name="y") + m.add_variables(0, pd.DataFrame([[1, 2], [3, 4], [5, 6]]).T, name="z") + m.add_variables(coords=[pd.RangeIndex(20, name="dim_2")], name="v") + idx = pd.MultiIndex.from_product([[1, 2], ["a", "b"]], names=("level1", "level2")) + idx.name = "dim_3" + m.add_variables(coords=[idx], name="u") + return m + + +@pytest.fixture +def x(m: Model) -> Variable: + return m.variables["x"] + + +@pytest.fixture +def y(m: Model) -> Variable: + return m.variables["y"] + + +@pytest.fixture +def z(m: Model) -> Variable: + return m.variables["z"] + + +@pytest.fixture +def v(m: Model) -> Variable: + return m.variables["v"] + + +@pytest.fixture +def u(m: Model) -> Variable: + return m.variables["u"] diff --git a/test/test_algebraic_properties.py b/test/test_algebraic_properties.py index 09548bf3..74e9e8dd 100644 --- a/test/test_algebraic_properties.py +++ b/test/test_algebraic_properties.py @@ -42,10 +42,19 @@ import pytest import xarray as xr +import linopy from linopy import Model from linopy.expressions import LinearExpression +@pytest.fixture(autouse=True) +def _use_exact_join(): + """Use exact arithmetic join for all tests in this module.""" + linopy.options["arithmetic_join"] = "exact" + yield + linopy.options["arithmetic_join"] = "legacy" + + @pytest.fixture def m(): return Model() diff --git a/test/test_common.py b/test/test_common.py index 4b84755a..72171211 100644 --- a/test/test_common.py +++ b/test/test_common.py @@ -10,10 +10,10 @@ import polars as pl import pytest import xarray as xr -from test_linear_expression import m, u, x # noqa: F401 from xarray import DataArray from xarray.testing.assertions import assert_equal +import linopy from linopy import LinearExpression, Model, Variable from linopy.common import ( align, @@ -28,6 +28,17 @@ from linopy.testing import assert_linequal, assert_varequal +@pytest.fixture(autouse=True) +def _use_exact_join(): + """Use exact arithmetic join for all tests in this module.""" + linopy.options["arithmetic_join"] = "exact" + yield + linopy.options["arithmetic_join"] = "legacy" + + +# Fixtures m, u, x are provided by conftest.py + + def test_as_dataarray_with_series_dims_default() -> None: target_dim = "dim_0" target_index = [0, 1, 2] diff --git a/test/test_common_legacy.py b/test/test_common_legacy.py new file mode 100644 index 00000000..7e623bf6 --- /dev/null +++ b/test/test_common_legacy.py @@ -0,0 +1,735 @@ +#!/usr/bin/env python3 +""" +Created on Mon Jun 19 12:11:03 2023 + +@author: fabian +""" + +import numpy as np +import pandas as pd +import polars as pl +import pytest +import xarray as xr +from xarray import DataArray +from xarray.testing.assertions import assert_equal + +from linopy import LinearExpression, Model, Variable +from linopy.common import ( + align, + as_dataarray, + assign_multiindex_safe, + best_int, + get_dims_with_index_levels, + is_constant, + iterate_slices, + maybe_group_terms_polars, +) +from linopy.testing import assert_linequal, assert_varequal + + +def test_as_dataarray_with_series_dims_default() -> None: + target_dim = "dim_0" + target_index = [0, 1, 2] + s = pd.Series([1, 2, 3]) + da = as_dataarray(s) + assert isinstance(da, DataArray) + assert da.dims == (target_dim,) + assert list(da.coords[target_dim].values) == target_index + + +def test_as_dataarray_with_series_dims_set() -> None: + target_dim = "dim1" + target_index = ["a", "b", "c"] + s = pd.Series([1, 2, 3], index=target_index) + dims = [target_dim] + da = as_dataarray(s, dims=dims) + assert isinstance(da, DataArray) + assert da.dims == (target_dim,) + assert list(da.coords[target_dim].values) == target_index + + +def test_as_dataarray_with_series_dims_given() -> None: + target_dim = "dim1" + target_index = ["a", "b", "c"] + index = pd.Index(target_index, name=target_dim) + s = pd.Series([1, 2, 3], index=index) + dims: list[str] = [] + da = as_dataarray(s, dims=dims) + assert isinstance(da, DataArray) + assert da.dims == (target_dim,) + assert list(da.coords[target_dim].values) == target_index + + +def test_as_dataarray_with_series_dims_priority() -> None: + """The dimension name from the pandas object should have priority.""" + target_dim = "dim1" + target_index = ["a", "b", "c"] + index = pd.Index(target_index, name=target_dim) + s = pd.Series([1, 2, 3], index=index) + dims = ["other"] + da = as_dataarray(s, dims=dims) + assert isinstance(da, DataArray) + assert da.dims == (target_dim,) + assert list(da.coords[target_dim].values) == target_index + + +def test_as_dataarray_with_series_dims_subset() -> None: + target_dim = "dim_0" + target_index = ["a", "b", "c"] + s = pd.Series([1, 2, 3], index=target_index) + dims: list[str] = [] + da = as_dataarray(s, dims=dims) + assert isinstance(da, DataArray) + assert da.dims == (target_dim,) + assert list(da.coords[target_dim].values) == target_index + + +def test_as_dataarray_with_series_dims_superset() -> None: + target_dim = "dim_a" + target_index = ["a", "b", "c"] + s = pd.Series([1, 2, 3], index=target_index) + dims = [target_dim, "other"] + da = as_dataarray(s, dims=dims) + assert isinstance(da, DataArray) + assert da.dims == (target_dim,) + assert list(da.coords[target_dim].values) == target_index + + +def test_as_dataarray_with_series_aligned_coords() -> None: + """This should not give out a warning even though coords are given.""" + target_dim = "dim_0" + target_index = ["a", "b", "c"] + s = pd.Series([1, 2, 3], index=target_index) + da = as_dataarray(s, coords=[target_index]) + assert isinstance(da, DataArray) + assert da.dims == (target_dim,) + assert list(da.coords[target_dim].values) == target_index + + da = as_dataarray(s, coords={target_dim: target_index}) + assert isinstance(da, DataArray) + assert da.dims == (target_dim,) + assert list(da.coords[target_dim].values) == target_index + + +def test_as_dataarray_with_pl_series_dims_default() -> None: + target_dim = "dim_0" + target_index = [0, 1, 2] + s = pl.Series([1, 2, 3]) + da = as_dataarray(s) + assert isinstance(da, DataArray) + assert da.dims == (target_dim,) + assert list(da.coords[target_dim].values) == target_index + + +def test_as_dataarray_dataframe_dims_default() -> None: + target_dims = ("dim_0", "dim_1") + target_index = [0, 1] + target_columns = ["A", "B"] + df = pd.DataFrame([[1, 2], [3, 4]], index=target_index, columns=target_columns) + da = as_dataarray(df) + assert isinstance(da, DataArray) + assert da.dims == target_dims + assert list(da.coords[target_dims[0]].values) == target_index + assert list(da.coords[target_dims[1]].values) == target_columns + + +def test_as_dataarray_dataframe_dims_set() -> None: + target_dims = ("dim1", "dim2") + target_index = ["a", "b"] + target_columns = ["A", "B"] + df = pd.DataFrame([[1, 2], [3, 4]], index=target_index, columns=target_columns) + da = as_dataarray(df, dims=target_dims) + assert isinstance(da, DataArray) + assert da.dims == target_dims + assert list(da.coords[target_dims[0]].values) == target_index + assert list(da.coords[target_dims[1]].values) == target_columns + + +def test_as_dataarray_dataframe_dims_given() -> None: + target_dims = ("dim1", "dim2") + target_index = ["a", "b"] + target_columns = ["A", "B"] + index = pd.Index(target_index, name=target_dims[0]) + columns = pd.Index(target_columns, name=target_dims[1]) + df = pd.DataFrame([[1, 2], [3, 4]], index=index, columns=columns) + dims: list[str] = [] + da = as_dataarray(df, dims=dims) + assert isinstance(da, DataArray) + assert da.dims == target_dims + assert list(da.coords[target_dims[0]].values) == target_index + assert list(da.coords[target_dims[1]].values) == target_columns + + +def test_as_dataarray_dataframe_dims_priority() -> None: + """The dimension name from the pandas object should have priority.""" + target_dims = ("dim1", "dim2") + target_index = ["a", "b"] + target_columns = ["A", "B"] + index = pd.Index(target_index, name=target_dims[0]) + columns = pd.Index(target_columns, name=target_dims[1]) + df = pd.DataFrame([[1, 2], [3, 4]], index=index, columns=columns) + dims = ["other"] + da = as_dataarray(df, dims=dims) + assert isinstance(da, DataArray) + assert da.dims == target_dims + assert list(da.coords[target_dims[0]].values) == target_index + assert list(da.coords[target_dims[1]].values) == target_columns + + +def test_as_dataarray_dataframe_dims_subset() -> None: + target_dims = ("dim_0", "dim_1") + target_index = ["a", "b"] + target_columns = ["A", "B"] + df = pd.DataFrame([[1, 2], [3, 4]], index=target_index, columns=target_columns) + dims: list[str] = [] + da = as_dataarray(df, dims=dims) + assert isinstance(da, DataArray) + assert da.dims == target_dims + assert list(da.coords[target_dims[0]].values) == target_index + assert list(da.coords[target_dims[1]].values) == target_columns + + +def test_as_dataarray_dataframe_dims_superset() -> None: + target_dims = ("dim_a", "dim_b") + target_index = ["a", "b"] + target_columns = ["A", "B"] + df = pd.DataFrame([[1, 2], [3, 4]], index=target_index, columns=target_columns) + dims = [*target_dims, "other"] + da = as_dataarray(df, dims=dims) + assert isinstance(da, DataArray) + assert da.dims == target_dims + assert list(da.coords[target_dims[0]].values) == target_index + assert list(da.coords[target_dims[1]].values) == target_columns + + +def test_as_dataarray_dataframe_aligned_coords() -> None: + """This should not give out a warning even though coords are given.""" + target_dims = ("dim_0", "dim_1") + target_index = ["a", "b"] + target_columns = ["A", "B"] + df = pd.DataFrame([[1, 2], [3, 4]], index=target_index, columns=target_columns) + da = as_dataarray(df, coords=[target_index, target_columns]) + assert isinstance(da, DataArray) + assert da.dims == target_dims + assert list(da.coords[target_dims[0]].values) == target_index + assert list(da.coords[target_dims[1]].values) == target_columns + + coords = dict(zip(target_dims, [target_index, target_columns])) + da = as_dataarray(df, coords=coords) + assert isinstance(da, DataArray) + assert da.dims == target_dims + assert list(da.coords[target_dims[0]].values) == target_index + assert list(da.coords[target_dims[1]].values) == target_columns + + +def test_as_dataarray_with_ndarray_no_coords_no_dims() -> None: + target_dims = ("dim_0", "dim_1") + target_coords = [[0, 1], [0, 1]] + arr = np.array([[1, 2], [3, 4]]) + da = as_dataarray(arr) + assert isinstance(da, DataArray) + assert da.dims == target_dims + for i, dim in enumerate(target_dims): + assert list(da.coords[dim]) == target_coords[i] + + +def test_as_dataarray_with_ndarray_coords_list_no_dims() -> None: + target_dims = ("dim_0", "dim_1") + target_coords = [["a", "b"], ["A", "B"]] + arr = np.array([[1, 2], [3, 4]]) + da = as_dataarray(arr, coords=target_coords) + assert isinstance(da, DataArray) + assert da.dims == target_dims + for i, dim in enumerate(target_dims): + assert list(da.coords[dim]) == target_coords[i] + + +def test_as_dataarray_with_ndarray_coords_indexes_no_dims() -> None: + target_dims = ("dim1", "dim2") + target_coords = [ + pd.Index(["a", "b"], name="dim1"), + pd.Index(["A", "B"], name="dim2"), + ] + arr = np.array([[1, 2], [3, 4]]) + da = as_dataarray(arr, coords=target_coords) + assert isinstance(da, DataArray) + assert da.dims == target_dims + for i, dim in enumerate(target_dims): + assert list(da.coords[dim]) == list(target_coords[i]) + + +def test_as_dataarray_with_ndarray_coords_dict_set_no_dims() -> None: + """If no dims are given and coords are a dict, the keys of the dict should be used as dims.""" + target_dims = ("dim_0", "dim_2") + target_coords = {"dim_0": ["a", "b"], "dim_2": ["A", "B"]} + arr = np.array([[1, 2], [3, 4]]) + da = as_dataarray(arr, coords=target_coords) + assert isinstance(da, DataArray) + assert da.dims == target_dims + for dim in target_dims: + assert list(da.coords[dim]) == target_coords[dim] + + +def test_as_dataarray_with_ndarray_coords_list_dims() -> None: + target_dims = ("dim1", "dim2") + target_coords = [["a", "b"], ["A", "B"]] + arr = np.array([[1, 2], [3, 4]]) + da = as_dataarray(arr, coords=target_coords, dims=target_dims) + assert isinstance(da, DataArray) + assert da.dims == target_dims + for i, dim in enumerate(target_dims): + assert list(da.coords[dim]) == target_coords[i] + + +def test_as_dataarray_with_ndarray_coords_list_dims_superset() -> None: + target_dims = ("dim1", "dim2") + target_coords = [["a", "b"], ["A", "B"]] + arr = np.array([[1, 2], [3, 4]]) + dims = [*target_dims, "dim3"] + da = as_dataarray(arr, coords=target_coords, dims=dims) + assert isinstance(da, DataArray) + assert da.dims == target_dims + for i, dim in enumerate(target_dims): + assert list(da.coords[dim]) == target_coords[i] + + +def test_as_dataarray_with_ndarray_coords_list_dims_subset() -> None: + target_dims = ("dim0", "dim_1") + target_coords = [["a", "b"], ["A", "B"]] + arr = np.array([[1, 2], [3, 4]]) + dims = ["dim0"] + da = as_dataarray(arr, coords=target_coords, dims=dims) + assert isinstance(da, DataArray) + assert da.dims == target_dims + for i, dim in enumerate(target_dims): + assert list(da.coords[dim]) == target_coords[i] + + +def test_as_dataarray_with_ndarray_coords_indexes_dims_aligned() -> None: + target_dims = ("dim1", "dim2") + target_coords = [ + pd.Index(["a", "b"], name="dim1"), + pd.Index(["A", "B"], name="dim2"), + ] + arr = np.array([[1, 2], [3, 4]]) + da = as_dataarray(arr, coords=target_coords, dims=target_dims) + assert isinstance(da, DataArray) + assert da.dims == target_dims + for i, dim in enumerate(target_dims): + assert list(da.coords[dim]) == list(target_coords[i]) + + +def test_as_dataarray_with_ndarray_coords_indexes_dims_not_aligned() -> None: + target_dims = ("dim3", "dim4") + target_coords = [ + pd.Index(["a", "b"], name="dim1"), + pd.Index(["A", "B"], name="dim2"), + ] + arr = np.array([[1, 2], [3, 4]]) + with pytest.raises(ValueError): + as_dataarray(arr, coords=target_coords, dims=target_dims) + + +def test_as_dataarray_with_ndarray_coords_dict_dims_aligned() -> None: + target_dims = ("dim_0", "dim_1") + target_coords = {"dim_0": ["a", "b"], "dim_1": ["A", "B"]} + arr = np.array([[1, 2], [3, 4]]) + da = as_dataarray(arr, coords=target_coords, dims=target_dims) + assert isinstance(da, DataArray) + assert da.dims == target_dims + for dim in target_dims: + assert list(da.coords[dim]) == target_coords[dim] + + +def test_as_dataarray_with_ndarray_coords_dict_set_dims_not_aligned() -> None: + target_dims = ("dim_0", "dim_1") + target_coords = {"dim_0": ["a", "b"], "dim_2": ["A", "B"]} + arr = np.array([[1, 2], [3, 4]]) + da = as_dataarray(arr, coords=target_coords, dims=target_dims) + assert da.dims == target_dims + assert list(da.coords["dim_0"].values) == ["a", "b"] + assert "dim_2" not in da.coords + + +def test_as_dataarray_with_number() -> None: + num = 1 + da = as_dataarray(num, dims=["dim1"], coords=[["a"]]) + assert isinstance(da, DataArray) + assert da.dims == ("dim1",) + assert list(da.coords["dim1"].values) == ["a"] + + +def test_as_dataarray_with_np_number() -> None: + num = np.float64(1) + da = as_dataarray(num, dims=["dim1"], coords=[["a"]]) + assert isinstance(da, DataArray) + assert da.dims == ("dim1",) + assert list(da.coords["dim1"].values) == ["a"] + + +def test_as_dataarray_with_number_default_dims_coords() -> None: + num = 1 + da = as_dataarray(num) + assert isinstance(da, DataArray) + assert da.dims == () + assert da.coords == {} + + +def test_as_dataarray_with_number_and_coords() -> None: + num = 1 + da = as_dataarray(num, coords=[pd.RangeIndex(10, name="a")]) + assert isinstance(da, DataArray) + assert da.dims == ("a",) + assert list(da.coords["a"].values) == list(range(10)) + + +def test_as_dataarray_with_dataarray() -> None: + da_in = DataArray( + data=[[1, 2], [3, 4]], + dims=["dim1", "dim2"], + coords={"dim1": ["a", "b"], "dim2": ["A", "B"]}, + ) + da_out = as_dataarray(da_in, dims=["dim1", "dim2"], coords=[["a", "b"], ["A", "B"]]) + assert isinstance(da_out, DataArray) + assert da_out.dims == da_in.dims + assert list(da_out.coords["dim1"].values) == list(da_in.coords["dim1"].values) + assert list(da_out.coords["dim2"].values) == list(da_in.coords["dim2"].values) + + +def test_as_dataarray_with_dataarray_default_dims_coords() -> None: + da_in = DataArray( + data=[[1, 2], [3, 4]], + dims=["dim1", "dim2"], + coords={"dim1": ["a", "b"], "dim2": ["A", "B"]}, + ) + da_out = as_dataarray(da_in) + assert isinstance(da_out, DataArray) + assert da_out.dims == da_in.dims + assert list(da_out.coords["dim1"].values) == list(da_in.coords["dim1"].values) + assert list(da_out.coords["dim2"].values) == list(da_in.coords["dim2"].values) + + +def test_as_dataarray_with_unsupported_type() -> None: + with pytest.raises(TypeError): + as_dataarray(lambda x: 1, dims=["dim1"], coords=[["a"]]) + + +def test_best_int() -> None: + # Test for int8 + assert best_int(127) == np.int8 + # Test for int16 + assert best_int(128) == np.int16 + assert best_int(32767) == np.int16 + # Test for int32 + assert best_int(32768) == np.int32 + assert best_int(2147483647) == np.int32 + # Test for int64 + assert best_int(2147483648) == np.int64 + assert best_int(9223372036854775807) == np.int64 + + # Test for value too large + with pytest.raises( + ValueError, match=r"Value 9223372036854775808 is too large for int64." + ): + best_int(9223372036854775808) + + +def test_assign_multiindex_safe() -> None: + # Create a multi-indexed dataset + index = pd.MultiIndex.from_product([["A", "B"], [1, 2]], names=["letter", "number"]) + data = xr.DataArray([1, 2, 3, 4], dims=["index"], coords={"index": index}) + ds = xr.Dataset({"value": data}) + + # This would now warn about the index deletion of single index level + # ds["humidity"] = data + + # Case 1: Assigning a single DataArray + result = assign_multiindex_safe(ds, humidity=data) + assert "humidity" in result + assert "value" in result + assert result["humidity"].equals(data) + + # Case 2: Assigning a Dataset + result = assign_multiindex_safe(ds, **xr.Dataset({"humidity": data})) # type: ignore + assert "humidity" in result + assert "value" in result + assert result["humidity"].equals(data) + + # Case 3: Assigning multiple DataArrays + result = assign_multiindex_safe(ds, humidity=data, pressure=data) + assert "humidity" in result + assert "pressure" in result + assert "value" in result + assert result["humidity"].equals(data) + assert result["pressure"].equals(data) + + +def test_iterate_slices_basic() -> None: + ds = xr.Dataset( + {"var": (("x", "y"), np.random.rand(10, 10))}, # noqa: NPY002 + coords={"x": np.arange(10), "y": np.arange(10)}, + ) + slices = list(iterate_slices(ds, slice_size=20)) + assert len(slices) == 5 + for s in slices: + assert isinstance(s, xr.Dataset) + assert set(s.dims) == set(ds.dims) + + +def test_iterate_slices_with_exclude_dims() -> None: + ds = xr.Dataset( + {"var": (("x", "y"), np.random.rand(10, 20))}, # noqa: NPY002 + coords={"x": np.arange(10), "y": np.arange(20)}, + ) + slices = list(iterate_slices(ds, slice_size=20, slice_dims=["x"])) + assert len(slices) == 10 + for s in slices: + assert isinstance(s, xr.Dataset) + assert set(s.dims) == set(ds.dims) + + +def test_iterate_slices_large_max_size() -> None: + ds = xr.Dataset( + {"var": (("x", "y"), np.random.rand(10, 10))}, # noqa: NPY002 + coords={"x": np.arange(10), "y": np.arange(10)}, + ) + slices = list(iterate_slices(ds, slice_size=200)) + assert len(slices) == 1 + for s in slices: + assert isinstance(s, xr.Dataset) + assert set(s.dims) == set(ds.dims) + + +def test_iterate_slices_small_max_size() -> None: + ds = xr.Dataset( + {"var": (("x", "y"), np.random.rand(10, 20))}, # noqa: NPY002 + coords={"x": np.arange(10), "y": np.arange(20)}, + ) + slices = list(iterate_slices(ds, slice_size=8, slice_dims=["x"])) + assert ( + len(slices) == 10 + ) # goes to the smallest slice possible which is 1 for the x dimension + for s in slices: + assert isinstance(s, xr.Dataset) + assert set(s.dims) == set(ds.dims) + + +def test_iterate_slices_slice_size_none() -> None: + ds = xr.Dataset( + {"var": (("x", "y"), np.random.rand(10, 10))}, # noqa: NPY002 + coords={"x": np.arange(10), "y": np.arange(10)}, + ) + slices = list(iterate_slices(ds, slice_size=None)) + assert len(slices) == 1 + for s in slices: + assert ds.equals(s) + + +def test_iterate_slices_includes_last_slice() -> None: + ds = xr.Dataset( + {"var": (("x"), np.random.rand(10))}, # noqa: NPY002 + coords={"x": np.arange(10)}, + ) + slices = list(iterate_slices(ds, slice_size=3, slice_dims=["x"])) + assert len(slices) == 4 # 10 slices for dimension 'x' with size 10 + total_elements = sum(s.sizes["x"] for s in slices) + assert total_elements == ds.sizes["x"] # Ensure all elements are included + for s in slices: + assert isinstance(s, xr.Dataset) + assert set(s.dims) == set(ds.dims) + + +def test_iterate_slices_empty_slice_dims() -> None: + ds = xr.Dataset( + {"var": (("x", "y"), np.random.rand(10, 10))}, # noqa: NPY002 + coords={"x": np.arange(10), "y": np.arange(10)}, + ) + slices = list(iterate_slices(ds, slice_size=50, slice_dims=[])) + assert len(slices) == 1 + for s in slices: + assert ds.equals(s) + + +def test_iterate_slices_invalid_slice_dims() -> None: + ds = xr.Dataset( + {"var": (("x", "y"), np.random.rand(10, 10))}, # noqa: NPY002 + coords={"x": np.arange(10), "y": np.arange(10)}, + ) + with pytest.raises(ValueError): + list(iterate_slices(ds, slice_size=50, slice_dims=["z"])) + + +def test_iterate_slices_empty_dataset() -> None: + ds = xr.Dataset( + {"var": (("x", "y"), np.array([]).reshape(0, 0))}, coords={"x": [], "y": []} + ) + slices = list(iterate_slices(ds, slice_size=10, slice_dims=["x"])) + assert len(slices) == 1 + assert ds.equals(slices[0]) + + +def test_iterate_slices_single_element() -> None: + ds = xr.Dataset({"var": (("x", "y"), np.array([[1]]))}, coords={"x": [0], "y": [0]}) + slices = list(iterate_slices(ds, slice_size=1, slice_dims=["x"])) + assert len(slices) == 1 + assert ds.equals(slices[0]) + + +def test_get_dims_with_index_levels() -> None: + # Create test data + + # Case 1: Simple dataset with regular dimensions + ds1 = xr.Dataset( + {"temp": (("time", "lat"), np.random.rand(3, 2))}, # noqa: NPY002 + coords={"time": pd.date_range("2024-01-01", periods=3), "lat": [0, 1]}, + ) + + # Case 2: Dataset with a multi-index dimension + stations_index = pd.MultiIndex.from_product( + [["USA", "Canada"], ["NYC", "Toronto"]], names=["country", "city"] + ) + stations_coords = xr.Coordinates.from_pandas_multiindex(stations_index, "station") + ds2 = xr.Dataset( + {"temp": (("time", "station"), np.random.rand(3, 4))}, # noqa: NPY002 + coords={"time": pd.date_range("2024-01-01", periods=3), **stations_coords}, + ) + + # Case 3: Dataset with unnamed multi-index levels + unnamed_stations_index = pd.MultiIndex.from_product( + [["USA", "Canada"], ["NYC", "Toronto"]] + ) + unnamed_stations_coords = xr.Coordinates.from_pandas_multiindex( + unnamed_stations_index, "station" + ) + ds3 = xr.Dataset( + {"temp": (("time", "station"), np.random.rand(3, 4))}, # noqa: NPY002 + coords={ + "time": pd.date_range("2024-01-01", periods=3), + **unnamed_stations_coords, + }, + ) + + # Case 4: Dataset with multiple multi-indexed dimensions + locations_index = pd.MultiIndex.from_product( + [["North", "South"], ["A", "B"]], names=["region", "site"] + ) + locations_coords = xr.Coordinates.from_pandas_multiindex( + locations_index, "location" + ) + + ds4 = xr.Dataset( + {"temp": (("time", "station", "location"), np.random.rand(2, 4, 4))}, # noqa: NPY002 + coords={ + "time": pd.date_range("2024-01-01", periods=2), + **stations_coords, + **locations_coords, + }, + ) + + # Run tests + + # Test case 1: Regular dimensions + assert get_dims_with_index_levels(ds1) == ["time", "lat"] + + # Test case 2: Named multi-index + assert get_dims_with_index_levels(ds2) == ["time", "station (country, city)"] + + # Test case 3: Unnamed multi-index + assert get_dims_with_index_levels(ds3) == [ + "time", + "station (station_level_0, station_level_1)", + ] + + # Test case 4: Multiple multi-indices + expected = ["time", "station (country, city)", "location (region, site)"] + assert get_dims_with_index_levels(ds4) == expected + + # Test case 5: Empty dataset + ds5 = xr.Dataset() + assert get_dims_with_index_levels(ds5) == [] + + +@pytest.mark.xfail(reason="xarray MultiIndex alignment incompatibility") +def test_align(x: Variable, u: Variable) -> None: # noqa: F811 + alpha = xr.DataArray([1, 2], [[1, 2]]) + beta = xr.DataArray( + [1, 2, 3], + [ + ( + "dim_3", + pd.MultiIndex.from_tuples( + [(1, "b"), (2, "b"), (1, "c")], names=["level1", "level2"] + ), + ) + ], + ) + + # inner join + x_obs, alpha_obs = align(x, alpha) + assert isinstance(x_obs, Variable) + assert x_obs.shape == alpha_obs.shape == (1,) + assert_varequal(x_obs, x.loc[[1]]) + + # left-join + x_obs, alpha_obs = align(x, alpha, join="left") + assert x_obs.shape == alpha_obs.shape == (2,) + assert isinstance(x_obs, Variable) + assert_varequal(x_obs, x) + assert_equal(alpha_obs, DataArray([np.nan, 1], [[0, 1]])) + + # multiindex + beta_obs, u_obs = align(beta, u) + assert u_obs.shape == beta_obs.shape == (2,) + assert isinstance(u_obs, Variable) + assert_varequal(u_obs, u.loc[[(1, "b"), (2, "b")]]) + assert_equal(beta_obs, beta.loc[[(1, "b"), (2, "b")]]) + + # with linear expression + expr = 20 * x + x_obs, expr_obs, alpha_obs = align(x, expr, alpha) + assert x_obs.shape == alpha_obs.shape == (1,) + assert expr_obs.shape == (1, 1) # _term dim + assert isinstance(expr_obs, LinearExpression) + assert_linequal(expr_obs, expr.loc[[1]]) + + +def test_is_constant() -> None: + model = Model() + index = pd.Index(range(10), name="t") + a = model.add_variables(name="a", coords=[index]) + b = a.sel(t=1) + c = a * 2 + d = a * a + + non_constant = [a, b, c, d] + for nc in non_constant: + assert not is_constant(nc) + + constant_values = [ + 5, + 3.14, + np.int32(7), + np.float64(2.71), + pd.Series([1, 2, 3]), + np.array([4, 5, 6]), + xr.DataArray([k for k in range(10)], coords=[index]), + ] + for cv in constant_values: + assert is_constant(cv) + + +def test_maybe_group_terms_polars_no_duplicates() -> None: + """Fast path: distinct (labels, vars) pairs skip group_by.""" + df = pl.DataFrame({"labels": [0, 0], "vars": [1, 2], "coeffs": [3.0, 4.0]}) + result = maybe_group_terms_polars(df) + assert result.shape == (2, 3) + assert result.columns == ["labels", "vars", "coeffs"] + assert result["coeffs"].to_list() == [3.0, 4.0] + + +def test_maybe_group_terms_polars_with_duplicates() -> None: + """Slow path: duplicate (labels, vars) pairs trigger group_by.""" + df = pl.DataFrame({"labels": [0, 0], "vars": [1, 1], "coeffs": [3.0, 4.0]}) + result = maybe_group_terms_polars(df) + assert result.shape == (1, 3) + assert result["coeffs"].to_list() == [7.0] diff --git a/test/test_constraints.py b/test/test_constraints.py index b20b18cf..55f92f6e 100644 --- a/test/test_constraints.py +++ b/test/test_constraints.py @@ -12,9 +12,19 @@ import pytest import xarray as xr +import linopy from linopy import EQUAL, GREATER_EQUAL, LESS_EQUAL, Model from linopy.testing import assert_conequal + +@pytest.fixture(autouse=True) +def _use_exact_join(): + """Use exact arithmetic join for all tests in this module.""" + linopy.options["arithmetic_join"] = "exact" + yield + linopy.options["arithmetic_join"] = "legacy" + + # Test model functions diff --git a/test/test_constraints_legacy.py b/test/test_constraints_legacy.py new file mode 100644 index 00000000..9a467c8c --- /dev/null +++ b/test/test_constraints_legacy.py @@ -0,0 +1,448 @@ +#!/usr/bin/env python3 +""" +Created on Wed Mar 10 11:23:13 2021. + +@author: fabulous +""" + +from typing import Any + +import dask +import dask.array.core +import numpy as np +import pandas as pd +import pytest +import xarray as xr + +from linopy import EQUAL, GREATER_EQUAL, LESS_EQUAL, Model, Variable, available_solvers +from linopy.testing import assert_conequal + +# Test model functions + + +def test_constraint_assignment() -> None: + m: Model = Model() + + lower: xr.DataArray = xr.DataArray( + np.zeros((10, 10)), coords=[range(10), range(10)] + ) + upper: xr.DataArray = xr.DataArray(np.ones((10, 10)), coords=[range(10), range(10)]) + x = m.add_variables(lower, upper, name="x") + y = m.add_variables(name="y") + + con0 = m.add_constraints(1 * x + 10 * y, EQUAL, 0) + + for attr in m.constraints.dataset_attrs: + assert "con0" in getattr(m.constraints, attr) + + assert m.constraints.labels.con0.shape == (10, 10) + assert m.constraints.labels.con0.dtype == int + assert m.constraints.coeffs.con0.dtype in (int, float) + assert m.constraints.vars.con0.dtype in (int, float) + assert m.constraints.rhs.con0.dtype in (int, float) + + assert_conequal(m.constraints.con0, con0) + + +def test_constraint_equality() -> None: + m: Model = Model() + + lower: xr.DataArray = xr.DataArray( + np.zeros((10, 10)), coords=[range(10), range(10)] + ) + upper: xr.DataArray = xr.DataArray(np.ones((10, 10)), coords=[range(10), range(10)]) + x = m.add_variables(lower, upper, name="x") + y = m.add_variables(name="y") + + con0 = m.add_constraints(1 * x + 10 * y, EQUAL, 0) + + assert_conequal(con0, 1 * x + 10 * y == 0, strict=False) + assert_conequal(1 * x + 10 * y == 0, 1 * x + 10 * y == 0, strict=False) + + with pytest.raises(AssertionError): + assert_conequal(con0, 1 * x + 10 * y <= 0, strict=False) + + with pytest.raises(AssertionError): + assert_conequal(con0, 1 * x + 10 * y >= 0, strict=False) + + with pytest.raises(AssertionError): + assert_conequal(10 * y + 2 * x == 0, 1 * x + 10 * y == 0, strict=False) + + +def test_constraints_getattr_formatted() -> None: + m: Model = Model() + x = m.add_variables(0, 10, name="x") + m.add_constraints(1 * x == 0, name="con-0") + assert_conequal(m.constraints.con_0, m.constraints["con-0"]) + + +def test_anonymous_constraint_assignment() -> None: + m: Model = Model() + + lower = xr.DataArray(np.zeros((10, 10)), coords=[range(10), range(10)]) + upper = xr.DataArray(np.ones((10, 10)), coords=[range(10), range(10)]) + x = m.add_variables(lower, upper, name="x") + y = m.add_variables(name="y") + con = 1 * x + 10 * y == 0 + m.add_constraints(con) + + for attr in m.constraints.dataset_attrs: + assert "con0" in getattr(m.constraints, attr) + + assert m.constraints.labels.con0.shape == (10, 10) + assert m.constraints.labels.con0.dtype == int + assert m.constraints.coeffs.con0.dtype in (int, float) + assert m.constraints.vars.con0.dtype in (int, float) + assert m.constraints.rhs.con0.dtype in (int, float) + + +def test_constraint_assignment_with_tuples() -> None: + m: Model = Model() + + lower = xr.DataArray(np.zeros((10, 10)), coords=[range(10), range(10)]) + upper = xr.DataArray(np.ones((10, 10)), coords=[range(10), range(10)]) + x = m.add_variables(lower, upper) + y = m.add_variables() + + m.add_constraints([(1, x), (10, y)], EQUAL, 0, name="c") + for attr in m.constraints.dataset_attrs: + assert "c" in getattr(m.constraints, attr) + assert m.constraints.labels.c.shape == (10, 10) + + +def test_constraint_assignment_chunked() -> None: + # setting bounds with one pd.DataFrame and one pd.Series + m: Model = Model(chunk=5) + lower = pd.DataFrame(np.zeros((10, 10))) + upper = pd.Series(np.ones(10)) + x = m.add_variables(lower, upper) + m.add_constraints(x, GREATER_EQUAL, 0, name="c") + assert m.constraints.coeffs.c.data.shape == ( + 10, + 10, + 1, + ) + assert isinstance(m.constraints.coeffs.c.data, dask.array.core.Array) + + +def test_constraint_assignment_with_reindex() -> None: + m: Model = Model() + + lower = xr.DataArray(np.zeros((10, 10)), coords=[range(10), range(10)]) + upper = xr.DataArray(np.ones((10, 10)), coords=[range(10), range(10)]) + x = m.add_variables(lower, upper, name="x") + y = m.add_variables(name="y") + + m.add_constraints(1 * x + 10 * y, EQUAL, 0) + + shuffled_coords = [2, 1, 3, 4, 6, 5, 7, 9, 8, 0] + + con = x.loc[shuffled_coords] + y >= 10 + assert (con.coords["dim_0"].values == shuffled_coords).all() + + +@pytest.mark.parametrize( + "rhs_factory", + [ + pytest.param(lambda m, v: v, id="numpy"), + pytest.param(lambda m, v: xr.DataArray(v, dims=["dim_0"]), id="dataarray"), + pytest.param(lambda m, v: pd.Series(v, index=v), id="series"), + pytest.param( + lambda m, v: m.add_variables(coords=[v]), + id="variable", + ), + pytest.param( + lambda m, v: 2 * m.add_variables(coords=[v]) + 1, + id="linexpr", + ), + ], +) +def test_constraint_rhs_lower_dim(rhs_factory: Any) -> None: + m = Model() + naxis = np.arange(10, dtype=float) + maxis = np.arange(10).astype(str) + x = m.add_variables(coords=[naxis, maxis]) + y = m.add_variables(coords=[naxis, maxis]) + + c = m.add_constraints(x - y >= rhs_factory(m, naxis)) + assert c.shape == (10, 10) + + +@pytest.mark.parametrize( + "rhs_factory", + [ + pytest.param(lambda m: np.ones((5, 3)), id="numpy"), + pytest.param(lambda m: pd.DataFrame(np.ones((5, 3))), id="dataframe"), + ], +) +def test_constraint_rhs_higher_dim_constant_warns( + rhs_factory: Any, caplog: Any +) -> None: + m = Model() + x = m.add_variables(coords=[range(5)], name="x") + + with caplog.at_level("WARNING", logger="linopy.expressions"): + m.add_constraints(x >= rhs_factory(m)) + assert "dimensions" in caplog.text + + +def test_constraint_rhs_higher_dim_dataarray_reindexes() -> None: + """DataArray RHS with extra dims reindexes to expression coords (no raise).""" + m = Model() + x = m.add_variables(coords=[range(5)], name="x") + rhs = xr.DataArray(np.ones((5, 3)), dims=["dim_0", "extra"]) + + c = m.add_constraints(x >= rhs) + assert c.shape == (5, 3) + + +@pytest.mark.parametrize( + "rhs_factory", + [ + pytest.param( + lambda m: m.add_variables(coords=[range(5), range(3)]), + id="variable", + ), + pytest.param( + lambda m: 2 * m.add_variables(coords=[range(5), range(3)]) + 1, + id="linexpr", + ), + ], +) +def test_constraint_rhs_higher_dim_expression(rhs_factory: Any) -> None: + m = Model() + x = m.add_variables(coords=[range(5)], name="x") + + c = m.add_constraints(x >= rhs_factory(m)) + assert c.shape == (5, 3) + + +def test_wrong_constraint_assignment_repeated() -> None: + # repeated variable assignment is forbidden + m: Model = Model() + x = m.add_variables() + m.add_constraints(x, LESS_EQUAL, 0, name="con") + with pytest.raises(ValueError): + m.add_constraints(x, LESS_EQUAL, 0, name="con") + + +def test_masked_constraints() -> None: + m: Model = Model() + + lower = xr.DataArray(np.zeros((10, 10)), coords=[range(10), range(10)]) + upper = xr.DataArray(np.ones((10, 10)), coords=[range(10), range(10)]) + x = m.add_variables(lower, upper) + y = m.add_variables() + + mask = pd.Series([True] * 5 + [False] * 5) + m.add_constraints(1 * x + 10 * y, EQUAL, 0, mask=mask) + assert (m.constraints.labels.con0[0:5, :] != -1).all() + assert (m.constraints.labels.con0[5:10, :] == -1).all() + + +def test_masked_constraints_broadcast() -> None: + m: Model = Model() + + lower = xr.DataArray(np.zeros((10, 10)), coords=[range(10), range(10)]) + upper = xr.DataArray(np.ones((10, 10)), coords=[range(10), range(10)]) + x = m.add_variables(lower, upper) + y = m.add_variables() + + mask = pd.Series([True] * 5 + [False] * 5) + m.add_constraints(1 * x + 10 * y, EQUAL, 0, name="bc1", mask=mask) + assert (m.constraints.labels.bc1[0:5, :] != -1).all() + assert (m.constraints.labels.bc1[5:10, :] == -1).all() + + mask2 = xr.DataArray([True] * 5 + [False] * 5, dims=["dim_1"]) + m.add_constraints(1 * x + 10 * y, EQUAL, 0, name="bc2", mask=mask2) + assert (m.constraints.labels.bc2[:, 0:5] != -1).all() + assert (m.constraints.labels.bc2[:, 5:10] == -1).all() + + mask3 = xr.DataArray( + [True, True, False, False, False], + dims=["dim_0"], + coords={"dim_0": range(5)}, + ) + with pytest.warns(FutureWarning, match="Missing values will be filled"): + m.add_constraints(1 * x + 10 * y, EQUAL, 0, name="bc3", mask=mask3) + assert (m.constraints.labels.bc3[0:2, :] != -1).all() + assert (m.constraints.labels.bc3[2:5, :] == -1).all() + assert (m.constraints.labels.bc3[5:10, :] == -1).all() + + # Mask with extra dimension not in data should raise + mask4 = xr.DataArray([True, False], dims=["extra_dim"]) + with pytest.raises(AssertionError, match="not a subset"): + m.add_constraints(1 * x + 10 * y, EQUAL, 0, name="bc4", mask=mask4) + + +def test_non_aligned_constraints() -> None: + m: Model = Model() + + lower = xr.DataArray(np.zeros(10), coords=[range(10)]) + x = m.add_variables(lower, name="x") + + lower = xr.DataArray(np.zeros(8), coords=[range(8)]) + y = m.add_variables(lower, name="y") + + m.add_constraints(x == 0.0) + m.add_constraints(y == 0.0) + + with pytest.warns(UserWarning): + m.constraints.labels + + for dtype in m.constraints.labels.dtypes.values(): + assert np.issubdtype(dtype, np.integer) + + for dtype in m.constraints.coeffs.dtypes.values(): + assert np.issubdtype(dtype, np.floating) + + for dtype in m.constraints.vars.dtypes.values(): + assert np.issubdtype(dtype, np.integer) + + for dtype in m.constraints.rhs.dtypes.values(): + assert np.issubdtype(dtype, np.floating) + + +def test_constraints_flat() -> None: + m: Model = Model() + + lower = xr.DataArray(np.zeros((10, 10)), coords=[range(10), range(10)]) + upper = xr.DataArray(np.ones((10, 10)), coords=[range(10), range(10)]) + x = m.add_variables(lower, upper) + y = m.add_variables() + + assert isinstance(m.constraints.flat, pd.DataFrame) + assert m.constraints.flat.empty + with pytest.raises(ValueError): + m.constraints.to_matrix() + + m.add_constraints(1 * x + 10 * y, EQUAL, 0) + m.add_constraints(1 * x + 10 * y, LESS_EQUAL, 0) + m.add_constraints(1 * x + 10 * y, GREATER_EQUAL, 0) + + assert isinstance(m.constraints.flat, pd.DataFrame) + assert not m.constraints.flat.empty + + +def test_sanitize_infinities() -> None: + m: Model = Model() + + lower = xr.DataArray(np.zeros((10, 10)), coords=[range(10), range(10)]) + upper = xr.DataArray(np.ones((10, 10)), coords=[range(10), range(10)]) + x = m.add_variables(lower, upper, name="x") + y = m.add_variables(name="y") + + # Test correct infinities + m.add_constraints(x <= np.inf, name="con_inf") + m.add_constraints(y >= -np.inf, name="con_neg_inf") + m.constraints.sanitize_infinities() + assert (m.constraints["con_inf"].labels == -1).all() + assert (m.constraints["con_neg_inf"].labels == -1).all() + + # Test incorrect infinities + with pytest.raises(ValueError): + m.add_constraints(x >= np.inf, name="con_wrong_inf") + with pytest.raises(ValueError): + m.add_constraints(y <= -np.inf, name="con_wrong_neg_inf") + + +class TestConstraintCoordinateAlignment: + @pytest.fixture(params=["xarray", "pandas_series"], ids=["da", "series"]) + def subset(self, request: Any) -> xr.DataArray | pd.Series: + if request.param == "xarray": + return xr.DataArray([10.0, 30.0], dims=["dim_2"], coords={"dim_2": [1, 3]}) + return pd.Series([10.0, 30.0], index=pd.Index([1, 3], name="dim_2")) + + @pytest.fixture(params=["xarray", "pandas_series"], ids=["da", "series"]) + def superset(self, request: Any) -> xr.DataArray | pd.Series: + if request.param == "xarray": + return xr.DataArray( + np.arange(25, dtype=float), + dims=["dim_2"], + coords={"dim_2": range(25)}, + ) + return pd.Series( + np.arange(25, dtype=float), index=pd.Index(range(25), name="dim_2") + ) + + def test_var_le_subset(self, v: Variable, subset: xr.DataArray) -> None: + con = v <= subset + assert con.sizes["dim_2"] == v.sizes["dim_2"] + assert con.rhs.sel(dim_2=1).item() == 10.0 + assert con.rhs.sel(dim_2=3).item() == 30.0 + assert np.isnan(con.rhs.sel(dim_2=0).item()) + + @pytest.mark.parametrize("sign", [LESS_EQUAL, GREATER_EQUAL, EQUAL]) + def test_var_comparison_subset( + self, v: Variable, subset: xr.DataArray, sign: str + ) -> None: + if sign == LESS_EQUAL: + con = v <= subset + elif sign == GREATER_EQUAL: + con = v >= subset + else: + con = v == subset + assert con.sizes["dim_2"] == v.sizes["dim_2"] + assert con.rhs.sel(dim_2=1).item() == 10.0 + assert np.isnan(con.rhs.sel(dim_2=0).item()) + + def test_expr_le_subset(self, v: Variable, subset: xr.DataArray) -> None: + expr = v + 5 + con = expr <= subset + assert con.sizes["dim_2"] == v.sizes["dim_2"] + assert con.rhs.sel(dim_2=1).item() == pytest.approx(5.0) + assert con.rhs.sel(dim_2=3).item() == pytest.approx(25.0) + assert np.isnan(con.rhs.sel(dim_2=0).item()) + + @pytest.mark.parametrize("sign", [LESS_EQUAL, GREATER_EQUAL, EQUAL]) + def test_subset_comparison_var( + self, v: Variable, subset: xr.DataArray, sign: str + ) -> None: + if sign == LESS_EQUAL: + con = subset <= v + elif sign == GREATER_EQUAL: + con = subset >= v + else: + con = subset == v + assert con.sizes["dim_2"] == v.sizes["dim_2"] + assert np.isnan(con.rhs.sel(dim_2=0).item()) + assert con.rhs.sel(dim_2=1).item() == pytest.approx(10.0) + + @pytest.mark.parametrize("sign", [LESS_EQUAL, GREATER_EQUAL]) + def test_superset_comparison_var( + self, v: Variable, superset: xr.DataArray, sign: str + ) -> None: + if sign == LESS_EQUAL: + con = superset <= v + else: + con = superset >= v + assert con.sizes["dim_2"] == v.sizes["dim_2"] + assert not np.isnan(con.lhs.coeffs.values).any() + assert not np.isnan(con.rhs.values).any() + + def test_constraint_rhs_extra_dims_broadcasts(self, v: Variable) -> None: + rhs = xr.DataArray( + [[1.0, 2.0]], + dims=["extra", "dim_2"], + coords={"dim_2": [0, 1]}, + ) + c = v <= rhs + assert "extra" in c.dims + + def test_subset_constraint_solve_integration(self) -> None: + if not available_solvers: + pytest.skip("No solver available") + solver = "highs" if "highs" in available_solvers else available_solvers[0] + m = Model() + coords = pd.RangeIndex(5, name="i") + x = m.add_variables(lower=0, upper=100, coords=[coords], name="x") + subset_ub = xr.DataArray([10.0, 20.0], dims=["i"], coords={"i": [1, 3]}) + m.add_constraints(x <= subset_ub, name="subset_ub") + m.add_objective(x.sum(), sense="max") + m.solve(solver_name=solver) + sol = m.solution["x"] + assert sol.sel(i=1).item() == pytest.approx(10.0) + assert sol.sel(i=3).item() == pytest.approx(20.0) + assert sol.sel(i=0).item() == pytest.approx(100.0) + assert sol.sel(i=2).item() == pytest.approx(100.0) + assert sol.sel(i=4).item() == pytest.approx(100.0) diff --git a/test/test_linear_expression.py b/test/test_linear_expression.py index ed808e78..81e5737d 100644 --- a/test/test_linear_expression.py +++ b/test/test_linear_expression.py @@ -14,6 +14,7 @@ import xarray as xr from xarray.testing import assert_equal +import linopy from linopy import LinearExpression, Model, QuadraticExpression, Variable, merge from linopy.constants import HELPER_DIMS, TERM_DIM from linopy.expressions import ScalarLinearExpression @@ -21,44 +22,15 @@ from linopy.variables import ScalarVariable -@pytest.fixture -def m() -> Model: - m = Model() +@pytest.fixture(autouse=True) +def _use_exact_join(): + """Use exact arithmetic join for all tests in this module.""" + linopy.options["arithmetic_join"] = "exact" + yield + linopy.options["arithmetic_join"] = "legacy" - m.add_variables(pd.Series([0, 0]), 1, name="x") - m.add_variables(4, pd.Series([8, 10]), name="y") - m.add_variables(0, pd.DataFrame([[1, 2], [3, 4], [5, 6]]).T, name="z") - m.add_variables(coords=[pd.RangeIndex(20, name="dim_2")], name="v") - idx = pd.MultiIndex.from_product([[1, 2], ["a", "b"]], names=("level1", "level2")) - idx.name = "dim_3" - m.add_variables(coords=[idx], name="u") - return m - - -@pytest.fixture -def x(m: Model) -> Variable: - return m.variables["x"] - - -@pytest.fixture -def y(m: Model) -> Variable: - return m.variables["y"] - - -@pytest.fixture -def z(m: Model) -> Variable: - return m.variables["z"] - - -@pytest.fixture -def v(m: Model) -> Variable: - return m.variables["v"] - - -@pytest.fixture -def u(m: Model) -> Variable: - return m.variables["u"] +# Fixtures m, x, y, z, v, u are provided by conftest.py def test_empty_linexpr(m: Model) -> None: diff --git a/test/test_linear_expression_legacy.py b/test/test_linear_expression_legacy.py new file mode 100644 index 00000000..2cfb315b --- /dev/null +++ b/test/test_linear_expression_legacy.py @@ -0,0 +1,2102 @@ +# ruff: noqa: D106 +#!/usr/bin/env python3 +""" +Created on Wed Mar 17 17:06:36 2021. + +@author: fabian +""" + +from __future__ import annotations + +from typing import Any + +import numpy as np +import pandas as pd +import polars as pl +import pytest +import xarray as xr +from xarray.testing import assert_equal + +from linopy import LinearExpression, Model, QuadraticExpression, Variable, merge +from linopy.constants import HELPER_DIMS, TERM_DIM +from linopy.expressions import ScalarLinearExpression +from linopy.testing import assert_linequal, assert_quadequal +from linopy.variables import ScalarVariable + + +def test_empty_linexpr(m: Model) -> None: + LinearExpression(None, m) + + +def test_linexpr_with_wrong_data(m: Model) -> None: + with pytest.raises(ValueError): + LinearExpression(xr.Dataset({"a": [1]}), m) + + coeffs = xr.DataArray([1, 2], dims=["a"]) + vars = xr.DataArray([1, 2], dims=["a"]) + data = xr.Dataset({"coeffs": coeffs, "vars": vars}) + with pytest.raises(ValueError): + LinearExpression(data, m) + + # with model as None + coeffs = xr.DataArray(np.array([1, 2]), dims=[TERM_DIM]) + vars = xr.DataArray(np.array([1, 2]), dims=[TERM_DIM]) + data = xr.Dataset({"coeffs": coeffs, "vars": vars}) + with pytest.raises(ValueError): + LinearExpression(data, None) # type: ignore + + +def test_linexpr_with_helper_dims_as_coords(m: Model) -> None: + coords = [pd.Index([0], name="a"), pd.Index([1, 2], name=TERM_DIM)] + coeffs = xr.DataArray(np.array([[1, 2]]), coords=coords) + vars = xr.DataArray(np.array([[1, 2]]), coords=coords) + + data = xr.Dataset({"coeffs": coeffs, "vars": vars}) + assert set(HELPER_DIMS).intersection(set(data.coords)) + + expr = LinearExpression(data, m) + assert not set(HELPER_DIMS).intersection(set(expr.data.coords)) + + +def test_linexpr_with_data_without_coords(m: Model) -> None: + lhs = 1 * m["x"] + vars = xr.DataArray(lhs.vars.values, dims=["dim_0", TERM_DIM]) + coeffs = xr.DataArray(lhs.coeffs.values, dims=["dim_0", TERM_DIM]) + data = xr.Dataset({"vars": vars, "coeffs": coeffs}) + expr = LinearExpression(data, m) + assert_linequal(expr, lhs) + + +def test_linexpr_from_constant_dataarray(m: Model) -> None: + const = xr.DataArray([1, 2], dims=["dim_0"]) + expr = LinearExpression(const, m) + assert (expr.const == const).all() + assert expr.nterm == 0 + + +def test_linexpr_from_constant_pl_series(m: Model) -> None: + const = pl.Series([1, 2]) + expr = LinearExpression(const, m) + assert (expr.const == const.to_numpy()).all() + assert expr.nterm == 0 + + +def test_linexpr_from_constant_pandas_series(m: Model) -> None: + const = pd.Series([1, 2], index=pd.RangeIndex(2, name="dim_0")) + expr = LinearExpression(const, m) + assert (expr.const == const).all() + assert expr.nterm == 0 + + +def test_linexpr_from_constant_pandas_dataframe(m: Model) -> None: + const = pd.DataFrame([[1, 2], [3, 4]], columns=["a", "b"]) + expr = LinearExpression(const, m) + assert (expr.const == const).all() + assert expr.nterm == 0 + + +def test_linexpr_from_constant_numpy_array(m: Model) -> None: + const = np.array([1, 2]) + expr = LinearExpression(const, m) + assert (expr.const == const).all() + assert expr.nterm == 0 + + +def test_linexpr_from_constant_scalar(m: Model) -> None: + const = 1 + expr = LinearExpression(const, m) + assert (expr.const == const).all() + assert expr.nterm == 0 + + +def test_repr(m: Model) -> None: + expr = m.linexpr((10, "x"), (1, "y")) + expr.__repr__() + + +def test_fill_value() -> None: + isinstance(LinearExpression._fill_value, dict) + + +def test_linexpr_with_scalars(m: Model) -> None: + expr = m.linexpr((10, "x"), (1, "y")) + target = xr.DataArray( + [[10, 1], [10, 1]], coords={"dim_0": [0, 1]}, dims=["dim_0", TERM_DIM] + ) + assert_equal(expr.coeffs, target) + + +def test_linexpr_with_variables_and_constants( + m: Model, x: Variable, y: Variable +) -> None: + expr = m.linexpr((10, x), (1, y), 2) + assert (expr.const == 2).all() + + +def test_linexpr_with_series(m: Model, v: Variable) -> None: + lhs = pd.Series(np.arange(20)), v + expr = m.linexpr(lhs) + isinstance(expr, LinearExpression) + + +def test_linexpr_with_dataframe(m: Model, z: Variable) -> None: + lhs = pd.DataFrame(z.labels), z + expr = m.linexpr(lhs) + isinstance(expr, LinearExpression) + + +def test_linexpr_duplicated_index(m: Model) -> None: + expr = m.linexpr((10, "x"), (-1, "x")) + assert (expr.data._term == [0, 1]).all() + + +def test_linear_expression_with_multiplication(x: Variable) -> None: + expr = 1 * x + assert isinstance(expr, LinearExpression) + assert expr.nterm == 1 + assert len(expr.vars.dim_0) == x.shape[0] + + expr = x * 1 + assert isinstance(expr, LinearExpression) + + expr2 = x.mul(1) + assert_linequal(expr, expr2) + + expr3 = expr.mul(1) + assert_linequal(expr, expr3) + + expr = x / 1 + assert isinstance(expr, LinearExpression) + + expr = x / 1.0 + assert isinstance(expr, LinearExpression) + + expr2 = x.div(1) + assert_linequal(expr, expr2) + + expr3 = expr.div(1) + assert_linequal(expr, expr3) + + expr = np.array([1, 2]) * x + assert isinstance(expr, LinearExpression) + + expr = np.array(1) * x + assert isinstance(expr, LinearExpression) + + expr = xr.DataArray(np.array([[1, 2], [2, 3]])) * x + assert isinstance(expr, LinearExpression) + + expr = pd.Series([1, 2], index=pd.RangeIndex(2, name="dim_0")) * x + assert isinstance(expr, LinearExpression) + + quad = x * x + assert isinstance(quad, QuadraticExpression) + + with pytest.raises(TypeError): + quad * quad + + expr = x * 1 + assert isinstance(expr, LinearExpression) + assert expr.__mul__(object()) is NotImplemented + assert expr.__rmul__(object()) is NotImplemented + + +def test_linear_expression_with_addition(m: Model, x: Variable, y: Variable) -> None: + expr = 10 * x + y + assert isinstance(expr, LinearExpression) + assert_linequal(expr, m.linexpr((10, "x"), (1, "y"))) + + expr = x + 8 * y + assert isinstance(expr, LinearExpression) + assert_linequal(expr, m.linexpr((1, "x"), (8, "y"))) + + expr = x + y + assert isinstance(expr, LinearExpression) + assert_linequal(expr, m.linexpr((1, "x"), (1, "y"))) + + expr2 = x.add(y) + assert_linequal(expr, expr2) + + expr3 = (x * 1).add(y) + assert_linequal(expr, expr3) + + expr3 = x + (x * x) + assert isinstance(expr3, QuadraticExpression) + + +def test_linear_expression_with_raddition(m: Model, x: Variable) -> None: + expr = x * 1.0 + expr_2: LinearExpression = 10.0 + expr + assert isinstance(expr, LinearExpression) + expr_3: LinearExpression = expr + 10.0 + assert_linequal(expr_2, expr_3) + + +def test_linear_expression_with_subtraction(m: Model, x: Variable, y: Variable) -> None: + expr = x - y + assert isinstance(expr, LinearExpression) + assert_linequal(expr, m.linexpr((1, "x"), (-1, "y"))) + + expr2 = x.sub(y) + assert_linequal(expr, expr2) + + expr3: LinearExpression = x * 1 + expr4 = expr3.sub(y) + assert_linequal(expr, expr4) + + expr = -x - 8 * y + assert isinstance(expr, LinearExpression) + assert_linequal(expr, m.linexpr((-1, "x"), (-8, "y"))) + + +def test_linear_expression_rsubtraction(x: Variable, y: Variable) -> None: + expr = x * 1.0 + expr_2: LinearExpression = 10.0 - expr + assert isinstance(expr_2, LinearExpression) + expr_3: LinearExpression = (expr - 10.0) * -1 + assert_linequal(expr_2, expr_3) + assert expr.__rsub__(object()) is NotImplemented + + +def test_linear_expression_with_constant(m: Model, x: Variable, y: Variable) -> None: + expr = x + 1 + assert isinstance(expr, LinearExpression) + assert (expr.const == 1).all() + + expr = -x - 8 * y - 10 + assert isinstance(expr, LinearExpression) + assert (expr.const == -10).all() + assert expr.nterm == 2 + + +def test_linear_expression_with_constant_multiplication( + m: Model, x: Variable, y: Variable +) -> None: + expr = x + 1 + + obs = expr * 10 + assert isinstance(obs, LinearExpression) + assert (obs.const == 10).all() + + obs = expr * pd.Series([1, 2, 3], index=pd.RangeIndex(3, name="new_dim")) + assert isinstance(obs, LinearExpression) + assert obs.shape == (2, 3, 1) + + +def test_linear_expression_multi_indexed(u: Variable) -> None: + expr = 3 * u + 1 * u + assert isinstance(expr, LinearExpression) + + +def test_linear_expression_with_errors(m: Model, x: Variable) -> None: + with pytest.raises(TypeError): + x / x + + with pytest.raises(TypeError): + x / (1 * x) + + with pytest.raises(TypeError): + m.linexpr((10, x.labels), (1, "y")) + + with pytest.raises(TypeError): + m.linexpr(a=2) # type: ignore + + +def test_linear_expression_from_rule(m: Model, x: Variable, y: Variable) -> None: + def bound(m: Model, i: int) -> ScalarLinearExpression: + return ( + (i - 1) * x.at[i - 1] + y.at[i] + 1 * x.at[i] + if i == 1 + else i * x.at[i] - y.at[i] + ) + + expr = LinearExpression.from_rule(m, bound, x.coords) + assert isinstance(expr, LinearExpression) + assert expr.nterm == 3 + repr(expr) # test repr + + +def test_linear_expression_from_rule_with_return_none( + m: Model, x: Variable, y: Variable +) -> None: + # with return type None + def bound(m: Model, i: int) -> ScalarLinearExpression | None: + if i == 1: + return (i - 1) * x.at[i - 1] + y.at[i] + return None + + expr = LinearExpression.from_rule(m, bound, x.coords) + assert isinstance(expr, LinearExpression) + assert (expr.vars[0] == -1).all() + assert (expr.vars[1] != -1).all() + assert expr.coeffs[0].isnull().all() + assert expr.coeffs[1].notnull().all() + repr(expr) # test repr + + +def test_linear_expression_addition(x: Variable, y: Variable, z: Variable) -> None: + expr = 10 * x + y + other = 2 * y + z + res = expr + other + + assert res.nterm == expr.nterm + other.nterm + assert (res.coords["dim_0"] == expr.coords["dim_0"]).all() + assert (res.coords["dim_1"] == other.coords["dim_1"]).all() + assert res.data.notnull().all().to_array().all() + + res2 = expr.add(other) + assert_linequal(res, res2) + + assert isinstance(x - expr, LinearExpression) + assert isinstance(x + expr, LinearExpression) + + +def test_linear_expression_addition_with_constant( + x: Variable, y: Variable, z: Variable +) -> None: + expr = 10 * x + y + 10 + assert (expr.const == 10).all() + + expr = 10 * x + y + np.array([2, 3]) + assert list(expr.const) == [2, 3] + + expr = 10 * x + y + pd.Series([2, 3]) + assert list(expr.const) == [2, 3] + + +def test_linear_expression_subtraction(x: Variable, y: Variable, z: Variable) -> None: + expr = 10 * x + y - 10 + assert (expr.const == -10).all() + + expr = 10 * x + y - np.array([2, 3]) + assert list(expr.const) == [-2, -3] + + expr = 10 * x + y - pd.Series([2, 3]) + assert list(expr.const) == [-2, -3] + + +def test_linear_expression_substraction( + x: Variable, y: Variable, z: Variable, v: Variable +) -> None: + expr = 10 * x + y + other = 2 * y - z + res = expr - other + + assert res.nterm == expr.nterm + other.nterm + assert (res.coords["dim_0"] == expr.coords["dim_0"]).all() + assert (res.coords["dim_1"] == other.coords["dim_1"]).all() + assert res.data.notnull().all().to_array().all() + + +def test_linear_expression_sum( + x: Variable, y: Variable, z: Variable, v: Variable +) -> None: + expr = 10 * x + y + z + res = expr.sum("dim_0") + + assert res.size == expr.size + assert res.nterm == expr.nterm * len(expr.data.dim_0) + + res = expr.sum() + assert res.size == expr.size + assert res.nterm == expr.size + assert res.data.notnull().all().to_array().all() + + assert_linequal(expr.sum(["dim_0", TERM_DIM]), expr.sum("dim_0")) + + # test special case otherride coords + expr = v.loc[:9] + v.loc[10:] + assert expr.nterm == 2 + assert len(expr.coords["dim_2"]) == 10 + + +def test_linear_expression_sum_with_const( + x: Variable, y: Variable, z: Variable, v: Variable +) -> None: + expr = 10 * x + y + z + 10 + res = expr.sum("dim_0") + + assert res.size == expr.size + assert res.nterm == expr.nterm * len(expr.data.dim_0) + assert (res.const == 20).all() + + res = expr.sum() + assert res.size == expr.size + assert res.nterm == expr.size + assert res.data.notnull().all().to_array().all() + assert (res.const == 60).item() + + assert_linequal(expr.sum(["dim_0", TERM_DIM]), expr.sum("dim_0")) + + # test special case otherride coords + expr = v.loc[:9] + v.loc[10:] + assert expr.nterm == 2 + assert len(expr.coords["dim_2"]) == 10 + + +def test_linear_expression_sum_drop_zeros(z: Variable) -> None: + coeff = xr.zeros_like(z.labels) + coeff[1, 0] = 3 + coeff[0, 2] = 5 + expr = coeff * z + + res = expr.sum("dim_0", drop_zeros=True) + assert res.nterm == 1 + + res = expr.sum("dim_1", drop_zeros=True) + assert res.nterm == 1 + + coeff[1, 2] = 4 + expr.data["coeffs"] = coeff + res = expr.sum() + + res = expr.sum("dim_0", drop_zeros=True) + assert res.nterm == 2 + + res = expr.sum("dim_1", drop_zeros=True) + assert res.nterm == 2 + + +def test_linear_expression_sum_warn_using_dims(z: Variable) -> None: + with pytest.warns(DeprecationWarning): + (1 * z).sum(dims="dim_0") + + +def test_linear_expression_sum_warn_unknown_kwargs(z: Variable) -> None: + with pytest.raises(ValueError): + (1 * z).sum(unknown_kwarg="dim_0") + + +def test_linear_expression_power(x: Variable) -> None: + expr: LinearExpression = x * 1.0 + qd_expr = expr**2 + assert isinstance(qd_expr, QuadraticExpression) + + qd_expr2 = expr.pow(2) + assert_quadequal(qd_expr, qd_expr2) + + with pytest.raises(ValueError): + expr**3 + + +def test_linear_expression_multiplication( + x: Variable, y: Variable, z: Variable +) -> None: + expr = 10 * x + y + z + mexpr = expr * 10 + assert (mexpr.coeffs.sel(dim_1=0, dim_0=0, _term=0) == 100).item() + + mexpr = 10 * expr + assert (mexpr.coeffs.sel(dim_1=0, dim_0=0, _term=0) == 100).item() + + mexpr = expr / 100 + assert (mexpr.coeffs.sel(dim_1=0, dim_0=0, _term=0) == 1 / 10).item() + + mexpr = expr / 100.0 + assert (mexpr.coeffs.sel(dim_1=0, dim_0=0, _term=0) == 1 / 10).item() + + +def test_matmul_variable_and_const(x: Variable, y: Variable) -> None: + const = np.array([1, 2]) + expr = x @ const + assert expr.nterm == 2 + assert_linequal(expr, (x * const).sum()) + + assert_linequal(x @ const, (x * const).sum()) + + assert_linequal(x.dot(const), x @ const) + + +def test_matmul_expr_and_const(x: Variable, y: Variable) -> None: + expr = 10 * x + y + const = np.array([1, 2]) + res = expr @ const + target = (10 * x) @ const + y @ const + assert res.nterm == 4 + assert_linequal(res, target) + + assert_linequal(expr.dot(const), target) + + +def test_matmul_wrong_input(x: Variable, y: Variable, z: Variable) -> None: + expr = 10 * x + y + z + with pytest.raises(TypeError): + expr @ expr + + +def test_linear_expression_multiplication_invalid( + x: Variable, y: Variable, z: Variable +) -> None: + expr = 10 * x + y + z + + with pytest.raises(TypeError): + expr = 10 * x + y + z + expr * expr + + with pytest.raises(TypeError): + expr = 10 * x + y + z + expr / x + + +class TestCoordinateAlignment: + @pytest.fixture(params=["da", "series"]) + def subset(self, request: Any) -> xr.DataArray | pd.Series: + if request.param == "da": + return xr.DataArray([10.0, 30.0], dims=["dim_2"], coords={"dim_2": [1, 3]}) + return pd.Series([10.0, 30.0], index=pd.Index([1, 3], name="dim_2")) + + @pytest.fixture(params=["da", "series"]) + def superset(self, request: Any) -> xr.DataArray | pd.Series: + if request.param == "da": + return xr.DataArray( + np.arange(25, dtype=float), + dims=["dim_2"], + coords={"dim_2": range(25)}, + ) + return pd.Series( + np.arange(25, dtype=float), index=pd.Index(range(25), name="dim_2") + ) + + @pytest.fixture + def expected_fill(self) -> np.ndarray: + arr = np.zeros(20) + arr[1] = 10.0 + arr[3] = 30.0 + return arr + + @pytest.fixture(params=["xarray", "pandas_series"], ids=["da", "series"]) + def nan_constant(self, request: Any) -> xr.DataArray | pd.Series: + vals = np.arange(20, dtype=float) + vals[0] = np.nan + vals[5] = np.nan + vals[19] = np.nan + if request.param == "xarray": + return xr.DataArray(vals, dims=["dim_2"], coords={"dim_2": range(20)}) + return pd.Series(vals, index=pd.Index(range(20), name="dim_2")) + + class TestSubset: + @pytest.mark.parametrize("operand", ["var", "expr"]) + def test_mul_subset_fills_zeros( + self, + v: Variable, + subset: xr.DataArray, + expected_fill: np.ndarray, + operand: str, + ) -> None: + target = v if operand == "var" else 1 * v + result = target * subset + assert result.sizes["dim_2"] == v.sizes["dim_2"] + assert not np.isnan(result.coeffs.values).any() + np.testing.assert_array_equal(result.coeffs.squeeze().values, expected_fill) + + @pytest.mark.parametrize("operand", ["var", "expr"]) + def test_add_subset_fills_zeros( + self, + v: Variable, + subset: xr.DataArray, + expected_fill: np.ndarray, + operand: str, + ) -> None: + if operand == "var": + result = v + subset + expected = expected_fill + else: + result = (v + 5) + subset + expected = expected_fill + 5 + assert result.sizes["dim_2"] == v.sizes["dim_2"] + assert not np.isnan(result.const.values).any() + np.testing.assert_array_equal(result.const.values, expected) + + @pytest.mark.parametrize("operand", ["var", "expr"]) + def test_sub_subset_fills_negated( + self, + v: Variable, + subset: xr.DataArray, + expected_fill: np.ndarray, + operand: str, + ) -> None: + if operand == "var": + result = v - subset + expected = -expected_fill + else: + result = (v + 5) - subset + expected = 5 - expected_fill + assert result.sizes["dim_2"] == v.sizes["dim_2"] + assert not np.isnan(result.const.values).any() + np.testing.assert_array_equal(result.const.values, expected) + + @pytest.mark.parametrize("operand", ["var", "expr"]) + def test_div_subset_inverts_nonzero( + self, v: Variable, subset: xr.DataArray, operand: str + ) -> None: + target = v if operand == "var" else 1 * v + result = target / subset + assert result.sizes["dim_2"] == v.sizes["dim_2"] + assert not np.isnan(result.coeffs.values).any() + assert result.coeffs.squeeze().sel(dim_2=1).item() == pytest.approx(0.1) + assert result.coeffs.squeeze().sel(dim_2=0).item() == pytest.approx(1.0) + + def test_subset_add_var_coefficients( + self, v: Variable, subset: xr.DataArray + ) -> None: + result = subset + v + np.testing.assert_array_equal(result.coeffs.squeeze().values, np.ones(20)) + + def test_subset_sub_var_coefficients( + self, v: Variable, subset: xr.DataArray + ) -> None: + result = subset - v + np.testing.assert_array_equal(result.coeffs.squeeze().values, -np.ones(20)) + + class TestSuperset: + def test_add_superset_pins_to_lhs_coords( + self, v: Variable, superset: xr.DataArray + ) -> None: + result = v + superset + assert result.sizes["dim_2"] == v.sizes["dim_2"] + assert not np.isnan(result.const.values).any() + + def test_add_var_commutative(self, v: Variable, superset: xr.DataArray) -> None: + assert_linequal(superset + v, v + superset) + + def test_sub_var_commutative(self, v: Variable, superset: xr.DataArray) -> None: + assert_linequal(superset - v, -v + superset) + + def test_mul_var_commutative(self, v: Variable, superset: xr.DataArray) -> None: + assert_linequal(superset * v, v * superset) + + def test_mul_superset_pins_to_lhs_coords( + self, v: Variable, superset: xr.DataArray + ) -> None: + result = v * superset + assert result.sizes["dim_2"] == v.sizes["dim_2"] + assert not np.isnan(result.coeffs.values).any() + + def test_div_superset_pins_to_lhs_coords(self, v: Variable) -> None: + superset_nonzero = xr.DataArray( + np.arange(1, 26, dtype=float), + dims=["dim_2"], + coords={"dim_2": range(25)}, + ) + result = v / superset_nonzero + assert result.sizes["dim_2"] == v.sizes["dim_2"] + assert not np.isnan(result.coeffs.values).any() + + class TestDisjoint: + def test_add_disjoint_fills_zeros(self, v: Variable) -> None: + disjoint = xr.DataArray( + [100.0, 200.0], dims=["dim_2"], coords={"dim_2": [50, 60]} + ) + result = v + disjoint + assert result.sizes["dim_2"] == v.sizes["dim_2"] + assert not np.isnan(result.const.values).any() + np.testing.assert_array_equal(result.const.values, np.zeros(20)) + + def test_mul_disjoint_fills_zeros(self, v: Variable) -> None: + disjoint = xr.DataArray( + [10.0, 20.0], dims=["dim_2"], coords={"dim_2": [50, 60]} + ) + result = v * disjoint + assert result.sizes["dim_2"] == v.sizes["dim_2"] + assert not np.isnan(result.coeffs.values).any() + np.testing.assert_array_equal(result.coeffs.squeeze().values, np.zeros(20)) + + def test_div_disjoint_preserves_coeffs(self, v: Variable) -> None: + disjoint = xr.DataArray( + [10.0, 20.0], dims=["dim_2"], coords={"dim_2": [50, 60]} + ) + result = v / disjoint + assert result.sizes["dim_2"] == v.sizes["dim_2"] + assert not np.isnan(result.coeffs.values).any() + np.testing.assert_array_equal(result.coeffs.squeeze().values, np.ones(20)) + + class TestCommutativity: + @pytest.mark.parametrize( + "make_lhs,make_rhs", + [ + (lambda v, s: s * v, lambda v, s: v * s), + (lambda v, s: s * (1 * v), lambda v, s: (1 * v) * s), + (lambda v, s: s + v, lambda v, s: v + s), + (lambda v, s: s + (v + 5), lambda v, s: (v + 5) + s), + ], + ids=["subset*var", "subset*expr", "subset+var", "subset+expr"], + ) + def test_commutativity( + self, + v: Variable, + subset: xr.DataArray, + make_lhs: Any, + make_rhs: Any, + ) -> None: + assert_linequal(make_lhs(v, subset), make_rhs(v, subset)) + + def test_sub_var_anticommutative( + self, v: Variable, subset: xr.DataArray + ) -> None: + assert_linequal(subset - v, -v + subset) + + def test_sub_expr_anticommutative( + self, v: Variable, subset: xr.DataArray + ) -> None: + expr = v + 5 + assert_linequal(subset - expr, -(expr - subset)) + + def test_add_commutativity_full_coords(self, v: Variable) -> None: + full = xr.DataArray( + np.arange(20, dtype=float), + dims=["dim_2"], + coords={"dim_2": range(20)}, + ) + assert_linequal(v + full, full + v) + + class TestQuadratic: + def test_quadexpr_add_subset( + self, + v: Variable, + subset: xr.DataArray, + expected_fill: np.ndarray, + ) -> None: + qexpr = v * v + result = qexpr + subset + assert isinstance(result, QuadraticExpression) + assert result.sizes["dim_2"] == v.sizes["dim_2"] + assert not np.isnan(result.const.values).any() + np.testing.assert_array_equal(result.const.values, expected_fill) + + def test_quadexpr_sub_subset( + self, + v: Variable, + subset: xr.DataArray, + expected_fill: np.ndarray, + ) -> None: + qexpr = v * v + result = qexpr - subset + assert isinstance(result, QuadraticExpression) + assert result.sizes["dim_2"] == v.sizes["dim_2"] + assert not np.isnan(result.const.values).any() + np.testing.assert_array_equal(result.const.values, -expected_fill) + + def test_quadexpr_mul_subset( + self, + v: Variable, + subset: xr.DataArray, + expected_fill: np.ndarray, + ) -> None: + qexpr = v * v + result = qexpr * subset + assert isinstance(result, QuadraticExpression) + assert result.sizes["dim_2"] == v.sizes["dim_2"] + assert not np.isnan(result.coeffs.values).any() + np.testing.assert_array_equal(result.coeffs.squeeze().values, expected_fill) + + def test_subset_mul_quadexpr( + self, + v: Variable, + subset: xr.DataArray, + expected_fill: np.ndarray, + ) -> None: + qexpr = v * v + result = subset * qexpr + assert isinstance(result, QuadraticExpression) + assert result.sizes["dim_2"] == v.sizes["dim_2"] + assert not np.isnan(result.coeffs.values).any() + np.testing.assert_array_equal(result.coeffs.squeeze().values, expected_fill) + + def test_subset_add_quadexpr(self, v: Variable, subset: xr.DataArray) -> None: + qexpr = v * v + assert_quadequal(subset + qexpr, qexpr + subset) + + class TestMissingValues: + """Same shape as variable but with NaN entries in the constant.""" + + EXPECTED_NAN_MASK = np.zeros(20, dtype=bool) + EXPECTED_NAN_MASK[[0, 5, 19]] = True + + @pytest.mark.parametrize("operand", ["var", "expr"]) + def test_add_nan_propagates( + self, + v: Variable, + nan_constant: xr.DataArray | pd.Series, + operand: str, + ) -> None: + target = v if operand == "var" else v + 5 + result = target + nan_constant + assert result.sizes["dim_2"] == 20 + np.testing.assert_array_equal( + np.isnan(result.const.values), self.EXPECTED_NAN_MASK + ) + + @pytest.mark.parametrize("operand", ["var", "expr"]) + def test_sub_nan_propagates( + self, + v: Variable, + nan_constant: xr.DataArray | pd.Series, + operand: str, + ) -> None: + target = v if operand == "var" else v + 5 + result = target - nan_constant + assert result.sizes["dim_2"] == 20 + np.testing.assert_array_equal( + np.isnan(result.const.values), self.EXPECTED_NAN_MASK + ) + + @pytest.mark.parametrize("operand", ["var", "expr"]) + def test_mul_nan_propagates( + self, + v: Variable, + nan_constant: xr.DataArray | pd.Series, + operand: str, + ) -> None: + target = v if operand == "var" else 1 * v + result = target * nan_constant + assert result.sizes["dim_2"] == 20 + np.testing.assert_array_equal( + np.isnan(result.coeffs.squeeze().values), self.EXPECTED_NAN_MASK + ) + + @pytest.mark.parametrize("operand", ["var", "expr"]) + def test_div_nan_propagates( + self, + v: Variable, + nan_constant: xr.DataArray | pd.Series, + operand: str, + ) -> None: + target = v if operand == "var" else 1 * v + result = target / nan_constant + assert result.sizes["dim_2"] == 20 + np.testing.assert_array_equal( + np.isnan(result.coeffs.squeeze().values), self.EXPECTED_NAN_MASK + ) + + def test_add_commutativity( + self, + v: Variable, + nan_constant: xr.DataArray | pd.Series, + ) -> None: + result_a = v + nan_constant + result_b = nan_constant + v + # Compare non-NaN values are equal and NaN positions match + nan_mask_a = np.isnan(result_a.const.values) + nan_mask_b = np.isnan(result_b.const.values) + np.testing.assert_array_equal(nan_mask_a, nan_mask_b) + np.testing.assert_array_equal( + result_a.const.values[~nan_mask_a], + result_b.const.values[~nan_mask_b], + ) + np.testing.assert_array_equal( + result_a.coeffs.values, result_b.coeffs.values + ) + + def test_mul_commutativity( + self, + v: Variable, + nan_constant: xr.DataArray | pd.Series, + ) -> None: + result_a = v * nan_constant + result_b = nan_constant * v + nan_mask_a = np.isnan(result_a.coeffs.values) + nan_mask_b = np.isnan(result_b.coeffs.values) + np.testing.assert_array_equal(nan_mask_a, nan_mask_b) + np.testing.assert_array_equal( + result_a.coeffs.values[~nan_mask_a], + result_b.coeffs.values[~nan_mask_b], + ) + + def test_quadexpr_add_nan( + self, + v: Variable, + nan_constant: xr.DataArray | pd.Series, + ) -> None: + qexpr = v * v + result = qexpr + nan_constant + assert isinstance(result, QuadraticExpression) + assert result.sizes["dim_2"] == 20 + np.testing.assert_array_equal( + np.isnan(result.const.values), self.EXPECTED_NAN_MASK + ) + + class TestMultiDim: + def test_multidim_subset_mul(self, m: Model) -> None: + coords_a = pd.RangeIndex(4, name="a") + coords_b = pd.RangeIndex(5, name="b") + w = m.add_variables(coords=[coords_a, coords_b], name="w") + + subset_2d = xr.DataArray( + [[2.0, 3.0], [4.0, 5.0]], + dims=["a", "b"], + coords={"a": [1, 3], "b": [0, 4]}, + ) + result = w * subset_2d + assert result.sizes["a"] == 4 + assert result.sizes["b"] == 5 + assert not np.isnan(result.coeffs.values).any() + assert result.coeffs.squeeze().sel(a=1, b=0).item() == pytest.approx(2.0) + assert result.coeffs.squeeze().sel(a=3, b=4).item() == pytest.approx(5.0) + assert result.coeffs.squeeze().sel(a=0, b=0).item() == pytest.approx(0.0) + assert result.coeffs.squeeze().sel(a=1, b=2).item() == pytest.approx(0.0) + + def test_multidim_subset_add(self, m: Model) -> None: + coords_a = pd.RangeIndex(4, name="a") + coords_b = pd.RangeIndex(5, name="b") + w = m.add_variables(coords=[coords_a, coords_b], name="w") + + subset_2d = xr.DataArray( + [[2.0, 3.0], [4.0, 5.0]], + dims=["a", "b"], + coords={"a": [1, 3], "b": [0, 4]}, + ) + result = w + subset_2d + assert result.sizes["a"] == 4 + assert result.sizes["b"] == 5 + assert not np.isnan(result.const.values).any() + assert result.const.sel(a=1, b=0).item() == pytest.approx(2.0) + assert result.const.sel(a=3, b=4).item() == pytest.approx(5.0) + assert result.const.sel(a=0, b=0).item() == pytest.approx(0.0) + + class TestXarrayCompat: + def test_da_eq_da_still_works(self) -> None: + da1 = xr.DataArray([1, 2, 3]) + da2 = xr.DataArray([1, 2, 3]) + result = da1 == da2 + assert result.values.all() + + def test_da_eq_scalar_still_works(self) -> None: + da = xr.DataArray([1, 2, 3]) + result = da == 2 + np.testing.assert_array_equal(result.values, [False, True, False]) + + def test_da_truediv_var_raises(self, v: Variable) -> None: + da = xr.DataArray(np.ones(20), dims=["dim_2"], coords={"dim_2": range(20)}) + with pytest.raises(TypeError): + da / v # type: ignore[operator] + + +def test_expression_inherited_properties(x: Variable, y: Variable) -> None: + expr = 10 * x + y + assert isinstance(expr.attrs, dict) + assert isinstance(expr.coords, xr.Coordinates) + assert isinstance(expr.indexes, xr.core.indexes.Indexes) + assert isinstance(expr.sizes, xr.core.utils.Frozen) + + +def test_linear_expression_getitem_single(x: Variable, y: Variable) -> None: + expr = 10 * x + y + 3 + sel = expr[0] + assert isinstance(sel, LinearExpression) + assert sel.nterm == 2 + # one expression with two terms (constant is not counted) + assert sel.size == 2 + + +def test_linear_expression_getitem_slice(x: Variable, y: Variable) -> None: + expr = 10 * x + y + 3 + sel = expr[:1] + + assert isinstance(sel, LinearExpression) + assert sel.nterm == 2 + # one expression with two terms (constant is not counted) + assert sel.size == 2 + + +def test_linear_expression_getitem_list(x: Variable, y: Variable, z: Variable) -> None: + expr = 10 * x + z + 10 + sel = expr[:, [0, 2]] + assert isinstance(sel, LinearExpression) + assert sel.nterm == 2 + # four expressions with two terms (constant is not counted) + assert sel.size == 8 + + +def test_linear_expression_loc(x: Variable, y: Variable) -> None: + expr = x + y + assert expr.loc[0].size < expr.loc[:5].size + + +def test_linear_expression_empty(v: Variable) -> None: + expr = 7 * v + assert not expr.empty + assert expr.loc[[]].empty + + with pytest.warns(DeprecationWarning, match="use `.empty` property instead"): + assert expr.loc[[]].empty() + + +def test_linear_expression_isnull(v: Variable) -> None: + expr = np.arange(20) * v + filter = (expr.coeffs >= 10).any(TERM_DIM) + expr = expr.where(filter) + assert expr.isnull().sum() == 10 + + +def test_linear_expression_flat(v: Variable) -> None: + coeff = np.arange(1, 21) # use non-zero coefficients + expr = coeff * v + df = expr.flat + assert isinstance(df, pd.DataFrame) + assert (df.coeffs == coeff).all() + + +def test_iterate_slices(x: Variable, y: Variable) -> None: + expr = x + 10 * y + for s in expr.iterate_slices(slice_size=2): + assert isinstance(s, LinearExpression) + assert s.nterm == expr.nterm + assert s.coord_dims == expr.coord_dims + + +def test_linear_expression_to_polars(v: Variable) -> None: + coeff = np.arange(1, 21) # use non-zero coefficients + expr = coeff * v + df = expr.to_polars() + assert isinstance(df, pl.DataFrame) + assert (df["coeffs"].to_numpy() == coeff).all() + + +def test_linear_expression_where(v: Variable) -> None: + expr = np.arange(20) * v + filter = (expr.coeffs >= 10).any(TERM_DIM) + expr = expr.where(filter) + assert isinstance(expr, LinearExpression) + assert expr.nterm == 1 + + expr = np.arange(20) * v + expr = expr.where(filter, drop=True).sum() + assert isinstance(expr, LinearExpression) + assert expr.nterm == 10 + + +def test_linear_expression_where_with_const(v: Variable) -> None: + expr = np.arange(20) * v + 10 + filter = (expr.coeffs >= 10).any(TERM_DIM) + expr = expr.where(filter) + assert isinstance(expr, LinearExpression) + assert expr.nterm == 1 + assert expr.const[:10].isnull().all() + assert (expr.const[10:] == 10).all() + + expr = np.arange(20) * v + 10 + expr = expr.where(filter, drop=True).sum() + assert isinstance(expr, LinearExpression) + assert expr.nterm == 10 + assert expr.const == 100 + + +def test_linear_expression_where_scalar_fill_value(v: Variable) -> None: + expr = np.arange(20) * v + 10 + filter = (expr.coeffs >= 10).any(TERM_DIM) + expr = expr.where(filter, 200) + assert isinstance(expr, LinearExpression) + assert expr.nterm == 1 + assert (expr.const[:10] == 200).all() + assert (expr.const[10:] == 10).all() + + +def test_linear_expression_where_array_fill_value(v: Variable) -> None: + expr = np.arange(20) * v + 10 + filter = (expr.coeffs >= 10).any(TERM_DIM) + other = expr.coeffs + expr = expr.where(filter, other) + assert isinstance(expr, LinearExpression) + assert expr.nterm == 1 + assert (expr.const[:10] == other[:10]).all() + assert (expr.const[10:] == 10).all() + + +def test_linear_expression_where_expr_fill_value(v: Variable) -> None: + expr = np.arange(20) * v + 10 + expr2 = np.arange(20) * v + 5 + filter = (expr.coeffs >= 10).any(TERM_DIM) + res = expr.where(filter, expr2) + assert isinstance(res, LinearExpression) + assert res.nterm == 1 + assert (res.const[:10] == expr2.const[:10]).all() + assert (res.const[10:] == 10).all() + + +def test_where_with_helper_dim_false(v: Variable) -> None: + expr = np.arange(20) * v + with pytest.raises(ValueError): + filter = expr.coeffs >= 10 + expr.where(filter) + + +def test_linear_expression_shift(v: Variable) -> None: + shifted = v.to_linexpr().shift(dim_2=2) + assert shifted.nterm == 1 + assert shifted.coeffs.loc[:1].isnull().all() + assert (shifted.vars.loc[:1] == -1).all() + + +def test_linear_expression_swap_dims(v: Variable) -> None: + expr = v.to_linexpr() + expr = expr.assign_coords({"second": ("dim_2", expr.indexes["dim_2"] + 100)}) + expr = expr.swap_dims({"dim_2": "second"}) + assert isinstance(expr, LinearExpression) + assert expr.coord_dims == ("second",) + + +def test_linear_expression_set_index(v: Variable) -> None: + expr = v.to_linexpr() + expr = expr.assign_coords({"second": ("dim_2", expr.indexes["dim_2"] + 100)}) + expr = expr.set_index({"multi": ["dim_2", "second"]}) + assert isinstance(expr, LinearExpression) + assert expr.coord_dims == ("multi",) + assert isinstance(expr.indexes["multi"], pd.MultiIndex) + + +def test_linear_expression_fillna(v: Variable) -> None: + expr = np.arange(20) * v + 10 + assert expr.const.sum() == 200 + + filter = (expr.coeffs >= 10).any(TERM_DIM) + filtered = expr.where(filter) + assert isinstance(filtered, LinearExpression) + assert filtered.const.sum() == 100 + + filled = filtered.fillna(10) + assert isinstance(filled, LinearExpression) + assert filled.const.sum() == 200 + assert filled.coeffs.isnull().sum() == 10 + + +def test_variable_expand_dims(v: Variable) -> None: + result = v.to_linexpr().expand_dims("new_dim") + assert isinstance(result, LinearExpression) + assert result.coord_dims == ("dim_2", "new_dim") + + +def test_variable_stack(v: Variable) -> None: + result = v.to_linexpr().expand_dims("new_dim").stack(new=("new_dim", "dim_2")) + assert isinstance(result, LinearExpression) + assert result.coord_dims == ("new",) + + +def test_linear_expression_unstack(v: Variable) -> None: + result = v.to_linexpr().expand_dims("new_dim").stack(new=("new_dim", "dim_2")) + result = result.unstack("new") + assert isinstance(result, LinearExpression) + assert result.coord_dims == ("new_dim", "dim_2") + + +def test_linear_expression_diff(v: Variable) -> None: + diff = v.to_linexpr().diff("dim_2") + assert diff.nterm == 2 + + +@pytest.mark.parametrize("use_fallback", [True, False]) +def test_linear_expression_groupby(v: Variable, use_fallback: bool) -> None: + expr = 1 * v + dim = v.dims[0] + groups = xr.DataArray([1] * 10 + [2] * 10, coords=v.coords, name=dim) + grouped = expr.groupby(groups).sum(use_fallback=use_fallback) + assert dim in grouped.dims + assert (grouped.data[dim] == [1, 2]).all() + assert grouped.nterm == 10 + + +@pytest.mark.parametrize("use_fallback", [True, False]) +def test_linear_expression_groupby_on_same_name_as_target_dim( + v: Variable, use_fallback: bool +) -> None: + expr = 1 * v + groups = xr.DataArray([1] * 10 + [2] * 10, coords=v.coords) + grouped = expr.groupby(groups).sum(use_fallback=use_fallback) + assert "group" in grouped.dims + assert (grouped.data.group == [1, 2]).all() + assert grouped.nterm == 10 + + +@pytest.mark.parametrize("use_fallback", [True]) +def test_linear_expression_groupby_ndim(z: Variable, use_fallback: bool) -> None: + # TODO: implement fallback for n-dim groupby, see https://github.com/PyPSA/linopy/issues/299 + expr = 1 * z + groups = xr.DataArray([[1, 1, 2], [1, 3, 3]], coords=z.coords) + grouped = expr.groupby(groups).sum(use_fallback=use_fallback) + assert "group" in grouped.dims + # there are three groups, 1, 2 and 3, the largest group has 3 elements + assert (grouped.data.group == [1, 2, 3]).all() + assert grouped.nterm == 3 + + +@pytest.mark.parametrize("use_fallback", [True, False]) +def test_linear_expression_groupby_with_name(v: Variable, use_fallback: bool) -> None: + expr = 1 * v + groups = xr.DataArray([1] * 10 + [2] * 10, coords=v.coords, name="my_group") + grouped = expr.groupby(groups).sum(use_fallback=use_fallback) + assert "my_group" in grouped.dims + assert (grouped.data.my_group == [1, 2]).all() + assert grouped.nterm == 10 + + +@pytest.mark.parametrize("use_fallback", [True, False]) +def test_linear_expression_groupby_with_series(v: Variable, use_fallback: bool) -> None: + expr = 1 * v + groups = pd.Series([1] * 10 + [2] * 10, index=v.indexes["dim_2"]) + grouped = expr.groupby(groups).sum(use_fallback=use_fallback) + assert "group" in grouped.dims + assert (grouped.data.group == [1, 2]).all() + assert grouped.nterm == 10 + + +@pytest.mark.parametrize("use_fallback", [True, False]) +def test_linear_expression_groupby_series_with_name( + v: Variable, use_fallback: bool +) -> None: + expr = 1 * v + groups = pd.Series([1] * 10 + [2] * 10, index=v.indexes[v.dims[0]], name="my_group") + grouped = expr.groupby(groups).sum(use_fallback=use_fallback) + assert "my_group" in grouped.dims + assert (grouped.data.my_group == [1, 2]).all() + assert grouped.nterm == 10 + + +@pytest.mark.parametrize("use_fallback", [True, False]) +def test_linear_expression_groupby_with_series_with_same_group_name( + v: Variable, use_fallback: bool +) -> None: + """ + Test that the group by works with a series whose name is the same as + the dimension to group. + """ + expr = 1 * v + groups = pd.Series([1] * 10 + [2] * 10, index=v.indexes["dim_2"]) + groups.name = "dim_2" + grouped = expr.groupby(groups).sum(use_fallback=use_fallback) + assert "dim_2" in grouped.dims + assert (grouped.data.dim_2 == [1, 2]).all() + assert grouped.nterm == 10 + + +@pytest.mark.parametrize("use_fallback", [True, False]) +def test_linear_expression_groupby_with_series_on_multiindex( + u: Variable, use_fallback: bool +) -> None: + expr = 1 * u + len_grouped_dim = len(u.data["dim_3"]) + groups = pd.Series([1] * len_grouped_dim, index=u.indexes["dim_3"]) + grouped = expr.groupby(groups).sum(use_fallback=use_fallback) + assert "group" in grouped.dims + assert (grouped.data.group == [1]).all() + assert grouped.nterm == len_grouped_dim + + +@pytest.mark.parametrize("use_fallback", [True, False]) +def test_linear_expression_groupby_with_dataframe( + v: Variable, use_fallback: bool +) -> None: + expr = 1 * v + groups = pd.DataFrame( + {"a": [1] * 10 + [2] * 10, "b": list(range(4)) * 5}, index=v.indexes["dim_2"] + ) + if use_fallback: + with pytest.raises(ValueError): + expr.groupby(groups).sum(use_fallback=use_fallback) + return + + grouped = expr.groupby(groups).sum(use_fallback=use_fallback) + index = pd.MultiIndex.from_frame(groups) + assert "group" in grouped.dims + assert set(grouped.data.group.values) == set(index.values) + assert grouped.nterm == 3 + + +@pytest.mark.parametrize("use_fallback", [True, False]) +def test_linear_expression_groupby_with_dataframe_with_same_group_name( + v: Variable, use_fallback: bool +) -> None: + """ + Test that the group by works with a dataframe whose column name is the same as + the dimension to group. + """ + expr = 1 * v + groups = pd.DataFrame( + {"dim_2": [1] * 10 + [2] * 10, "b": list(range(4)) * 5}, + index=v.indexes["dim_2"], + ) + if use_fallback: + with pytest.raises(ValueError): + expr.groupby(groups).sum(use_fallback=use_fallback) + return + + grouped = expr.groupby(groups).sum(use_fallback=use_fallback) + index = pd.MultiIndex.from_frame(groups) + assert "group" in grouped.dims + assert set(grouped.data.group.values) == set(index.values) + assert grouped.nterm == 3 + + +@pytest.mark.parametrize("use_fallback", [True, False]) +def test_linear_expression_groupby_with_dataframe_on_multiindex( + u: Variable, use_fallback: bool +) -> None: + expr = 1 * u + len_grouped_dim = len(u.data["dim_3"]) + groups = pd.DataFrame({"a": [1] * len_grouped_dim}, index=u.indexes["dim_3"]) + + if use_fallback: + with pytest.raises(ValueError): + expr.groupby(groups).sum(use_fallback=use_fallback) + return + grouped = expr.groupby(groups).sum(use_fallback=use_fallback) + assert "group" in grouped.dims + assert isinstance(grouped.indexes["group"], pd.MultiIndex) + assert grouped.nterm == len_grouped_dim + + +@pytest.mark.parametrize("use_fallback", [True, False]) +def test_linear_expression_groupby_with_dataarray( + v: Variable, use_fallback: bool +) -> None: + expr = 1 * v + df = pd.DataFrame( + {"a": [1] * 10 + [2] * 10, "b": list(range(4)) * 5}, index=v.indexes["dim_2"] + ) + groups = xr.DataArray(df) + + # this should not be the case, see https://github.com/PyPSA/linopy/issues/351 + if use_fallback: + with pytest.raises((KeyError, IndexError)): + expr.groupby(groups).sum(use_fallback=use_fallback) + return + + grouped = expr.groupby(groups).sum(use_fallback=use_fallback) + index = pd.MultiIndex.from_frame(df) + assert "group" in grouped.dims + assert set(grouped.data.group.values) == set(index.values) + assert grouped.nterm == 3 + + +def test_linear_expression_groupby_with_dataframe_non_aligned(v: Variable) -> None: + expr = 1 * v + groups = pd.DataFrame( + {"a": [1] * 10 + [2] * 10, "b": list(range(4)) * 5}, index=v.indexes["dim_2"] + ) + target = expr.groupby(groups).sum() + + groups_non_aligned = groups[::-1] + grouped = expr.groupby(groups_non_aligned).sum() + assert_linequal(grouped, target) + + +@pytest.mark.parametrize("use_fallback", [True, False]) +def test_linear_expression_groupby_with_const(v: Variable, use_fallback: bool) -> None: + expr = 1 * v + 15 + groups = xr.DataArray([1] * 10 + [2] * 10, coords=v.coords) + grouped = expr.groupby(groups).sum(use_fallback=use_fallback) + assert "group" in grouped.dims + assert (grouped.data.group == [1, 2]).all() + assert grouped.nterm == 10 + assert (grouped.const == 150).all() + + +@pytest.mark.parametrize("use_fallback", [True, False]) +def test_linear_expression_groupby_asymmetric(v: Variable, use_fallback: bool) -> None: + expr = 1 * v + # now asymetric groups which result in different nterms + groups = xr.DataArray([1] * 12 + [2] * 8, coords=v.coords) + grouped = expr.groupby(groups).sum(use_fallback=use_fallback) + assert "group" in grouped.dims + # first group must be full with vars + assert (grouped.data.sel(group=1) > 0).all() + # the last 4 entries of the second group must be empty, i.e. -1 + assert (grouped.data.sel(group=2).isel(_term=slice(None, -4)).vars >= 0).all() + assert (grouped.data.sel(group=2).isel(_term=slice(-4, None)).vars == -1).all() + assert grouped.nterm == 12 + + +@pytest.mark.parametrize("use_fallback", [True, False]) +def test_linear_expression_groupby_asymmetric_with_const( + v: Variable, use_fallback: bool +) -> None: + expr = 1 * v + 15 + # now asymetric groups which result in different nterms + groups = xr.DataArray([1] * 12 + [2] * 8, coords=v.coords) + grouped = expr.groupby(groups).sum(use_fallback=use_fallback) + assert "group" in grouped.dims + # first group must be full with vars + assert (grouped.data.sel(group=1) > 0).all() + # the last 4 entries of the second group must be empty, i.e. -1 + assert (grouped.data.sel(group=2).isel(_term=slice(None, -4)).vars >= 0).all() + assert (grouped.data.sel(group=2).isel(_term=slice(-4, None)).vars == -1).all() + assert grouped.nterm == 12 + assert list(grouped.const) == [180, 120] + + +def test_linear_expression_groupby_roll(v: Variable) -> None: + expr = 1 * v + groups = xr.DataArray([1] * 10 + [2] * 10, coords=v.coords) + grouped = expr.groupby(groups).roll(dim_2=1) + assert grouped.nterm == 1 + assert grouped.vars[0].item() == 19 + + +def test_linear_expression_groupby_roll_with_const(v: Variable) -> None: + expr = 1 * v + np.arange(20) + groups = xr.DataArray([1] * 10 + [2] * 10, coords=v.coords) + grouped = expr.groupby(groups).roll(dim_2=1) + assert grouped.nterm == 1 + assert grouped.vars[0].item() == 19 + assert grouped.const[0].item() == 9 + + +def test_linear_expression_groupby_from_variable(v: Variable) -> None: + groups = xr.DataArray([1] * 10 + [2] * 10, coords=v.coords) + grouped = v.groupby(groups).sum() + assert "group" in grouped.dims + assert (grouped.data.group == [1, 2]).all() + assert grouped.nterm == 10 + + +def test_linear_expression_rolling(v: Variable) -> None: + expr = 1 * v + rolled = expr.rolling(dim_2=2).sum() + assert rolled.nterm == 2 + + rolled = expr.rolling(dim_2=3).sum() + assert rolled.nterm == 3 + + with pytest.raises(ValueError): + expr.rolling().sum() + + +def test_linear_expression_rolling_with_const(v: Variable) -> None: + expr = 1 * v + 15 + rolled = expr.rolling(dim_2=2).sum() + assert rolled.nterm == 2 + assert rolled.const[0].item() == 15 + assert (rolled.const[1:] == 30).all() + + rolled = expr.rolling(dim_2=3).sum() + assert rolled.nterm == 3 + assert rolled.const[0].item() == 15 + assert rolled.const[1].item() == 30 + assert (rolled.const[2:] == 45).all() + + +def test_linear_expression_rolling_from_variable(v: Variable) -> None: + rolled = v.rolling(dim_2=2).sum() + assert rolled.nterm == 2 + + +def test_linear_expression_from_tuples(x: Variable, y: Variable) -> None: + expr = LinearExpression.from_tuples((10, x), (1, y)) + assert isinstance(expr, LinearExpression) + + with pytest.warns(DeprecationWarning): + expr2 = LinearExpression.from_tuples((10, x), (1,)) + assert isinstance(expr2, LinearExpression) + assert (expr2.const == 1).all() + + expr3 = LinearExpression.from_tuples((10, x), 1) + assert isinstance(expr3, LinearExpression) + assert_linequal(expr2, expr3) + + expr4 = LinearExpression.from_tuples((10, x), (1, y), 1) + assert isinstance(expr4, LinearExpression) + assert (expr4.const == 1).all() + + expr5 = LinearExpression.from_tuples(1, model=x.model) + assert isinstance(expr5, LinearExpression) + + +def test_linear_expression_from_tuples_bad_calls( + m: Model, x: Variable, y: Variable +) -> None: + with pytest.raises(ValueError): + LinearExpression.from_tuples((10, x), (1, y), x) + + with pytest.raises(ValueError): + LinearExpression.from_tuples((10, x, 3), (1, y), 1) + + sv = ScalarVariable(label=0, model=m) + with pytest.raises(TypeError): + LinearExpression.from_tuples((np.array([1, 1]), sv)) + + with pytest.raises(TypeError): + LinearExpression.from_tuples((x, x)) + + with pytest.raises(ValueError): + LinearExpression.from_tuples(10) + + +def test_linear_expression_from_constant_scalar(m: Model) -> None: + expr = LinearExpression.from_constant(model=m, constant=10) + assert expr.is_constant + assert isinstance(expr, LinearExpression) + assert (expr.const == 10).all() + + +def test_linear_expression_from_constant_1D(m: Model) -> None: + arr = pd.Series(index=pd.Index([0, 1], name="t"), data=[10, 20]) + expr = LinearExpression.from_constant(model=m, constant=arr) + assert isinstance(expr, LinearExpression) + assert list(expr.coords.keys())[0] == "t" + assert expr.nterm == 0 + assert (expr.const.values == [10, 20]).all() + assert expr.is_constant + + +def test_constant_linear_expression_to_polars_2D(m: Model) -> None: + index_a = pd.Index([0, 1], name="a") + index_b = pd.Index([0, 1, 2], name="b") + arr = np.array([[10, 20, 30], [40, 50, 60]]) + const = xr.DataArray(data=arr, coords=[index_a, index_b]) + + le_variable = m.add_variables(name="var", coords=[index_a, index_b]) * 1 + const + assert not le_variable.is_constant + le_const = LinearExpression.from_constant(model=m, constant=const) + assert le_const.is_constant + + var_pol = le_variable.to_polars() + const_pol = le_const.to_polars() + assert var_pol.shape == const_pol.shape + assert var_pol.columns == const_pol.columns + assert all(const_pol["const"] == var_pol["const"]) + assert all(const_pol["coeffs"].is_null()) + assert all(const_pol["vars"].is_null()) + + +def test_linear_expression_sanitize(x: Variable, y: Variable, z: Variable) -> None: + expr = 10 * x + y + z + assert isinstance(expr.sanitize(), LinearExpression) + + +def test_merge(x: Variable, y: Variable, z: Variable) -> None: + expr1 = (10 * x + y).sum("dim_0") + expr2 = z.sum("dim_0") + + res = merge([expr1, expr2], cls=LinearExpression) + assert res.nterm == 6 + + res: LinearExpression = merge([expr1, expr2]) # type: ignore + assert isinstance(res, LinearExpression) + + # now concat with same length of terms + expr1 = z.sel(dim_0=0).sum("dim_1") + expr2 = z.sel(dim_0=1).sum("dim_1") + + res = merge([expr1, expr2], dim="dim_1", cls=LinearExpression) + assert res.nterm == 3 + + # now with different length of terms + expr1 = z.sel(dim_0=0, dim_1=slice(0, 1)).sum("dim_1") + expr2 = z.sel(dim_0=1).sum("dim_1") + + res = merge([expr1, expr2], dim="dim_1", cls=LinearExpression) + assert res.nterm == 3 + assert res.sel(dim_1=0).vars[2].item() == -1 + + with pytest.warns(DeprecationWarning): + merge(expr1, expr2) + + +def test_linear_expression_outer_sum(x: Variable, y: Variable) -> None: + expr = x + y + expr2: LinearExpression = sum([x, y]) # type: ignore + assert_linequal(expr, expr2) + + expr = 1 * x + 2 * y + expr2: LinearExpression = sum([1 * x, 2 * y]) # type: ignore + assert_linequal(expr, expr2) + + assert isinstance(expr.sum(), LinearExpression) + + +def test_rename(x: Variable, y: Variable, z: Variable) -> None: + expr = 10 * x + y + z + renamed = expr.rename({"dim_0": "dim_5"}) + assert set(renamed.dims) == {"dim_1", "dim_5", TERM_DIM} + assert renamed.nterm == 3 + + renamed = expr.rename({"dim_0": "dim_1", "dim_1": "dim_2"}) + assert set(renamed.dims) == {"dim_1", "dim_2", TERM_DIM} + assert renamed.nterm == 3 + + +@pytest.mark.parametrize("multiple", [1.0, 0.5, 2.0, 0.0]) +def test_cumsum(m: Model, multiple: float) -> None: + # Test cumsum on variable x + var = m.variables["x"] + cumsum = (multiple * var).cumsum() + cumsum.nterm == 2 + + # Test cumsum on sum of variables + expr = m.variables["x"] + m.variables["y"] + cumsum = (multiple * expr).cumsum() + cumsum.nterm == 2 + + +def test_simplify_basic(x: Variable) -> None: + """Test basic simplification with duplicate terms.""" + expr = 2 * x + 3 * x + 1 * x + simplified = expr.simplify() + assert simplified.nterm == 1, f"Expected 1 term, got {simplified.nterm}" + + x_len = len(x.coords["dim_0"]) + # Check that the coefficient is 6 (2 + 3 + 1) + coeffs: np.ndarray = simplified.coeffs.values + assert len(coeffs) == x_len, f"Expected {x_len} coefficients, got {len(coeffs)}" + assert all(coeffs == 6.0), f"Expected coefficient 6.0, got {coeffs[0]}" + + +def test_simplify_multiple_dimensions() -> None: + model = Model() + a_index = pd.Index([0, 1, 2, 3], name="a") + b_index = pd.Index([0, 1, 2], name="b") + coords = [a_index, b_index] + x = model.add_variables(name="x", coords=coords) + + expr = 2 * x + 3 * x + x + # Simplify + simplified = expr.simplify() + assert simplified.nterm == 1, f"Expected 1 term, got {simplified.nterm}" + assert simplified.ndim == 2, f"Expected 2 dimensions, got {simplified.ndim}" + assert all(simplified.coeffs.values.reshape(-1) == 6), ( + f"Expected coefficients of 6, got {simplified.coeffs.values}" + ) + + +def test_simplify_with_different_variables(x: Variable, y: Variable) -> None: + """Test that different variables are kept separate.""" + # Create expression: 2*x + 3*x + 4*y + expr = 2 * x + 3 * x + 4 * y + + # Simplify + simplified = expr.simplify() + # Should have 2 terms (one for x with coeff 5, one for y with coeff 4) + assert simplified.nterm == 2, f"Expected 2 terms, got {simplified.nterm}" + + coeffs: list[float] = simplified.coeffs.values.flatten().tolist() + assert set(coeffs) == {5.0, 4.0}, ( + f"Expected coefficients {{5.0, 4.0}}, got {set(coeffs)}" + ) + + +def test_simplify_with_constant(x: Variable) -> None: + """Test that constants are preserved.""" + expr = 2 * x + 3 * x + 10 + + # Simplify + simplified = expr.simplify() + + # Check constant is preserved + assert all(simplified.const.values == 10.0), ( + f"Expected constant 10.0, got {simplified.const.values}" + ) + + # Check coefficients + assert all(simplified.coeffs.values == 5.0), ( + f"Expected coefficient 5.0, got {simplified.coeffs.values}" + ) + + +def test_simplify_cancellation(x: Variable) -> None: + """Test that terms cancel out correctly when coefficients sum to zero.""" + expr = x - x + simplified = expr.simplify() + + assert simplified.nterm == 0, f"Expected 0 terms, got {simplified.nterm}" + assert simplified.coeffs.values.size == 0 + assert simplified.vars.values.size == 0 + + +def test_simplify_partial_cancellation(x: Variable, y: Variable) -> None: + """Test partial cancellation where some terms cancel but others remain.""" + expr = 2 * x - 2 * x + 3 * y + simplified = expr.simplify() + + assert simplified.nterm == 1, f"Expected 1 term, got {simplified.nterm}" + assert all(simplified.coeffs.values == 3.0), ( + f"Expected coefficient 3.0, got {simplified.coeffs.values}" + ) + + +def test_constant_only_expression_mul_dataarray(m: Model) -> None: + const_arr = xr.DataArray([2, 3], dims=["dim_0"]) + const_expr = LinearExpression(const_arr, m) + assert const_expr.is_constant + assert const_expr.nterm == 0 + + data_arr = xr.DataArray([10, 20], dims=["dim_0"]) + expected_const = const_arr * data_arr + + result = const_expr * data_arr + assert isinstance(result, LinearExpression) + assert result.is_constant + assert (result.const == expected_const).all() + + result_rev = data_arr * const_expr + assert isinstance(result_rev, LinearExpression) + assert result_rev.is_constant + assert (result_rev.const == expected_const).all() + + +def test_constant_only_expression_mul_linexpr_with_vars(m: Model, x: Variable) -> None: + const_arr = xr.DataArray([2, 3], dims=["dim_0"]) + const_expr = LinearExpression(const_arr, m) + assert const_expr.is_constant + assert const_expr.nterm == 0 + + expr_with_vars = 1 * x + 5 + expected_coeffs = const_arr + expected_const = const_arr * 5 + + result = const_expr * expr_with_vars + assert isinstance(result, LinearExpression) + assert (result.coeffs == expected_coeffs).all() + assert (result.const == expected_const).all() + + result_rev = expr_with_vars * const_expr + assert isinstance(result_rev, LinearExpression) + assert (result_rev.coeffs == expected_coeffs).all() + assert (result_rev.const == expected_const).all() + + +def test_constant_only_expression_mul_constant_only(m: Model) -> None: + const_arr = xr.DataArray([2, 3], dims=["dim_0"]) + const_arr2 = xr.DataArray([4, 5], dims=["dim_0"]) + const_expr = LinearExpression(const_arr, m) + const_expr2 = LinearExpression(const_arr2, m) + assert const_expr.is_constant + assert const_expr2.is_constant + + expected_const = const_arr * const_arr2 + + result = const_expr * const_expr2 + assert isinstance(result, LinearExpression) + assert result.is_constant + assert (result.const == expected_const).all() + + result_rev = const_expr2 * const_expr + assert isinstance(result_rev, LinearExpression) + assert result_rev.is_constant + assert (result_rev.const == expected_const).all() + + +def test_constant_only_expression_mul_linexpr_with_vars_and_const( + m: Model, x: Variable +) -> None: + const_arr = xr.DataArray([2, 3], dims=["dim_0"]) + const_expr = LinearExpression(const_arr, m) + assert const_expr.is_constant + + expr_with_vars_and_const = 4 * x + 10 + expected_coeffs = const_arr * 4 + expected_const = const_arr * 10 + + result = const_expr * expr_with_vars_and_const + assert isinstance(result, LinearExpression) + assert not result.is_constant + assert (result.coeffs == expected_coeffs).all() + assert (result.const == expected_const).all() + + result_rev = expr_with_vars_and_const * const_expr + assert isinstance(result_rev, LinearExpression) + assert not result_rev.is_constant + assert (result_rev.coeffs == expected_coeffs).all() + assert (result_rev.const == expected_const).all() + + +class TestJoinParameter: + @pytest.fixture + def m2(self) -> Model: + m = Model() + m.add_variables(coords=[pd.Index([0, 1, 2], name="i")], name="a") + m.add_variables(coords=[pd.Index([1, 2, 3], name="i")], name="b") + m.add_variables(coords=[pd.Index([0, 1, 2], name="i")], name="c") + return m + + @pytest.fixture + def a(self, m2: Model) -> Variable: + return m2.variables["a"] + + @pytest.fixture + def b(self, m2: Model) -> Variable: + return m2.variables["b"] + + @pytest.fixture + def c(self, m2: Model) -> Variable: + return m2.variables["c"] + + class TestAddition: + def test_add_join_none_preserves_default( + self, a: Variable, b: Variable + ) -> None: + result_default = a.to_linexpr() + b.to_linexpr() + result_none = a.to_linexpr().add(b.to_linexpr(), join=None) + assert_linequal(result_default, result_none) + + def test_add_expr_join_inner(self, a: Variable, b: Variable) -> None: + result = a.to_linexpr().add(b.to_linexpr(), join="inner") + assert list(result.data.indexes["i"]) == [1, 2] + + def test_add_expr_join_outer(self, a: Variable, b: Variable) -> None: + result = a.to_linexpr().add(b.to_linexpr(), join="outer") + assert list(result.data.indexes["i"]) == [0, 1, 2, 3] + + def test_add_expr_join_left(self, a: Variable, b: Variable) -> None: + result = a.to_linexpr().add(b.to_linexpr(), join="left") + assert list(result.data.indexes["i"]) == [0, 1, 2] + + def test_add_expr_join_right(self, a: Variable, b: Variable) -> None: + result = a.to_linexpr().add(b.to_linexpr(), join="right") + assert list(result.data.indexes["i"]) == [1, 2, 3] + + def test_add_constant_join_inner(self, a: Variable) -> None: + const = xr.DataArray([10, 20, 30], dims=["i"], coords={"i": [1, 2, 3]}) + result = a.to_linexpr().add(const, join="inner") + assert list(result.data.indexes["i"]) == [1, 2] + + def test_add_constant_join_outer(self, a: Variable) -> None: + const = xr.DataArray([10, 20, 30], dims=["i"], coords={"i": [1, 2, 3]}) + result = a.to_linexpr().add(const, join="outer") + assert list(result.data.indexes["i"]) == [0, 1, 2, 3] + + def test_add_constant_join_override(self, a: Variable, c: Variable) -> None: + expr = a.to_linexpr() + const = xr.DataArray([10, 20, 30], dims=["i"], coords={"i": [0, 1, 2]}) + result = expr.add(const, join="override") + assert list(result.data.indexes["i"]) == [0, 1, 2] + assert (result.const.values == const.values).all() + + def test_add_same_coords_all_joins(self, a: Variable, c: Variable) -> None: + expr_a = 1 * a + 5 + const = xr.DataArray([1, 2, 3], dims=["i"], coords={"i": [0, 1, 2]}) + for join in ["override", "outer", "inner"]: + result = expr_a.add(const, join=join) + assert list(result.coords["i"].values) == [0, 1, 2] + np.testing.assert_array_equal(result.const.values, [6, 7, 8]) + + def test_add_scalar_with_explicit_join(self, a: Variable) -> None: + expr = 1 * a + 5 + result = expr.add(10, join="override") + np.testing.assert_array_equal(result.const.values, [15, 15, 15]) + assert list(result.coords["i"].values) == [0, 1, 2] + + class TestSubtraction: + def test_sub_expr_join_inner(self, a: Variable, b: Variable) -> None: + result = a.to_linexpr().sub(b.to_linexpr(), join="inner") + assert list(result.data.indexes["i"]) == [1, 2] + + def test_sub_constant_override(self, a: Variable) -> None: + expr = 1 * a + 5 + other = xr.DataArray([10, 20, 30], dims=["i"], coords={"i": [5, 6, 7]}) + result = expr.sub(other, join="override") + assert list(result.coords["i"].values) == [0, 1, 2] + np.testing.assert_array_equal(result.const.values, [-5, -15, -25]) + + class TestMultiplication: + def test_mul_constant_join_inner(self, a: Variable) -> None: + const = xr.DataArray([2, 3, 4], dims=["i"], coords={"i": [1, 2, 3]}) + result = a.to_linexpr().mul(const, join="inner") + assert list(result.data.indexes["i"]) == [1, 2] + + def test_mul_constant_join_outer(self, a: Variable) -> None: + const = xr.DataArray([2, 3, 4], dims=["i"], coords={"i": [1, 2, 3]}) + result = a.to_linexpr().mul(const, join="outer") + assert list(result.data.indexes["i"]) == [0, 1, 2, 3] + assert result.coeffs.sel(i=0).item() == 0 + assert result.coeffs.sel(i=1).item() == 2 + assert result.coeffs.sel(i=2).item() == 3 + + def test_mul_expr_with_join_raises(self, a: Variable, b: Variable) -> None: + with pytest.raises(TypeError, match="join parameter is not supported"): + a.to_linexpr().mul(b.to_linexpr(), join="inner") + + class TestDivision: + def test_div_constant_join_inner(self, a: Variable) -> None: + const = xr.DataArray([2, 3, 4], dims=["i"], coords={"i": [1, 2, 3]}) + result = a.to_linexpr().div(const, join="inner") + assert list(result.data.indexes["i"]) == [1, 2] + + def test_div_constant_join_outer(self, a: Variable) -> None: + const = xr.DataArray([2, 3, 4], dims=["i"], coords={"i": [1, 2, 3]}) + result = a.to_linexpr().div(const, join="outer") + assert list(result.data.indexes["i"]) == [0, 1, 2, 3] + + def test_div_expr_with_join_raises(self, a: Variable, b: Variable) -> None: + with pytest.raises(TypeError): + a.to_linexpr().div(b.to_linexpr(), join="outer") + + class TestVariableOperations: + def test_variable_add_join(self, a: Variable, b: Variable) -> None: + result = a.add(b, join="inner") + assert list(result.data.indexes["i"]) == [1, 2] + + def test_variable_sub_join(self, a: Variable, b: Variable) -> None: + result = a.sub(b, join="inner") + assert list(result.data.indexes["i"]) == [1, 2] + + def test_variable_mul_join(self, a: Variable) -> None: + const = xr.DataArray([2, 3, 4], dims=["i"], coords={"i": [1, 2, 3]}) + result = a.mul(const, join="inner") + assert list(result.data.indexes["i"]) == [1, 2] + + def test_variable_div_join(self, a: Variable) -> None: + const = xr.DataArray([2, 3, 4], dims=["i"], coords={"i": [1, 2, 3]}) + result = a.div(const, join="inner") + assert list(result.data.indexes["i"]) == [1, 2] + + def test_variable_add_outer_values(self, a: Variable, b: Variable) -> None: + result = a.add(b, join="outer") + assert isinstance(result, LinearExpression) + assert set(result.coords["i"].values) == {0, 1, 2, 3} + assert result.nterm == 2 + + def test_variable_mul_override(self, a: Variable) -> None: + other = xr.DataArray([2, 3, 4], dims=["i"], coords={"i": [5, 6, 7]}) + result = a.mul(other, join="override") + assert isinstance(result, LinearExpression) + assert list(result.coords["i"].values) == [0, 1, 2] + np.testing.assert_array_equal(result.coeffs.squeeze().values, [2, 3, 4]) + + def test_variable_div_override(self, a: Variable) -> None: + other = xr.DataArray([2.0, 5.0, 10.0], dims=["i"], coords={"i": [5, 6, 7]}) + result = a.div(other, join="override") + assert isinstance(result, LinearExpression) + assert list(result.coords["i"].values) == [0, 1, 2] + np.testing.assert_array_almost_equal( + result.coeffs.squeeze().values, [0.5, 0.2, 0.1] + ) + + def test_same_shape_add_join_override(self, a: Variable, c: Variable) -> None: + result = a.to_linexpr().add(c.to_linexpr(), join="override") + assert list(result.data.indexes["i"]) == [0, 1, 2] + + class TestMerge: + def test_merge_join_parameter(self, a: Variable, b: Variable) -> None: + result: LinearExpression = merge( + [a.to_linexpr(), b.to_linexpr()], join="inner" + ) + assert list(result.data.indexes["i"]) == [1, 2] + + def test_merge_outer_join(self, a: Variable, b: Variable) -> None: + result: LinearExpression = merge( + [a.to_linexpr(), b.to_linexpr()], join="outer" + ) + assert set(result.coords["i"].values) == {0, 1, 2, 3} + + def test_merge_join_left(self, a: Variable, b: Variable) -> None: + result: LinearExpression = merge( + [a.to_linexpr(), b.to_linexpr()], join="left" + ) + assert list(result.data.indexes["i"]) == [0, 1, 2] + + def test_merge_join_right(self, a: Variable, b: Variable) -> None: + result: LinearExpression = merge( + [a.to_linexpr(), b.to_linexpr()], join="right" + ) + assert list(result.data.indexes["i"]) == [1, 2, 3] + + class TestValueVerification: + def test_add_expr_outer_const_values(self, a: Variable, b: Variable) -> None: + expr_a = 1 * a + 5 + expr_b = 2 * b + 10 + result = expr_a.add(expr_b, join="outer") + assert set(result.coords["i"].values) == {0, 1, 2, 3} + assert result.const.sel(i=0).item() == 5 + assert result.const.sel(i=1).item() == 15 + assert result.const.sel(i=2).item() == 15 + assert result.const.sel(i=3).item() == 10 + + def test_add_expr_inner_const_values(self, a: Variable, b: Variable) -> None: + expr_a = 1 * a + 5 + expr_b = 2 * b + 10 + result = expr_a.add(expr_b, join="inner") + assert list(result.coords["i"].values) == [1, 2] + assert result.const.sel(i=1).item() == 15 + assert result.const.sel(i=2).item() == 15 + + def test_add_constant_outer_fill_values(self, a: Variable) -> None: + expr = 1 * a + 5 + const = xr.DataArray([10, 20], dims=["i"], coords={"i": [1, 3]}) + result = expr.add(const, join="outer") + assert set(result.coords["i"].values) == {0, 1, 2, 3} + assert result.const.sel(i=0).item() == 5 + assert result.const.sel(i=1).item() == 15 + assert result.const.sel(i=2).item() == 5 + assert result.const.sel(i=3).item() == 20 + + def test_add_constant_inner_fill_values(self, a: Variable) -> None: + expr = 1 * a + 5 + const = xr.DataArray([10, 20], dims=["i"], coords={"i": [1, 3]}) + result = expr.add(const, join="inner") + assert list(result.coords["i"].values) == [1] + assert result.const.sel(i=1).item() == 15 + + def test_add_constant_override_positional(self, a: Variable) -> None: + expr = 1 * a + 5 + other = xr.DataArray([10, 20, 30], dims=["i"], coords={"i": [5, 6, 7]}) + result = expr.add(other, join="override") + assert list(result.coords["i"].values) == [0, 1, 2] + np.testing.assert_array_equal(result.const.values, [15, 25, 35]) + + def test_sub_expr_outer_const_values(self, a: Variable, b: Variable) -> None: + expr_a = 1 * a + 5 + expr_b = 2 * b + 10 + result = expr_a.sub(expr_b, join="outer") + assert set(result.coords["i"].values) == {0, 1, 2, 3} + assert result.const.sel(i=0).item() == 5 + assert result.const.sel(i=1).item() == -5 + assert result.const.sel(i=2).item() == -5 + assert result.const.sel(i=3).item() == -10 + + def test_mul_constant_override_positional(self, a: Variable) -> None: + expr = 1 * a + 5 + other = xr.DataArray([2, 3, 4], dims=["i"], coords={"i": [5, 6, 7]}) + result = expr.mul(other, join="override") + assert list(result.coords["i"].values) == [0, 1, 2] + np.testing.assert_array_equal(result.const.values, [10, 15, 20]) + np.testing.assert_array_equal(result.coeffs.squeeze().values, [2, 3, 4]) + + def test_mul_constant_outer_fill_values(self, a: Variable) -> None: + expr = 1 * a + 5 + other = xr.DataArray([2, 3], dims=["i"], coords={"i": [1, 3]}) + result = expr.mul(other, join="outer") + assert set(result.coords["i"].values) == {0, 1, 2, 3} + assert result.const.sel(i=0).item() == 0 + assert result.const.sel(i=1).item() == 10 + assert result.const.sel(i=2).item() == 0 + assert result.const.sel(i=3).item() == 0 + assert result.coeffs.squeeze().sel(i=1).item() == 2 + assert result.coeffs.squeeze().sel(i=0).item() == 0 + + def test_div_constant_override_positional(self, a: Variable) -> None: + expr = 1 * a + 10 + other = xr.DataArray([2.0, 5.0, 10.0], dims=["i"], coords={"i": [5, 6, 7]}) + result = expr.div(other, join="override") + assert list(result.coords["i"].values) == [0, 1, 2] + np.testing.assert_array_equal(result.const.values, [5.0, 2.0, 1.0]) + + def test_div_constant_outer_fill_values(self, a: Variable) -> None: + expr = 1 * a + 10 + other = xr.DataArray([2.0, 5.0], dims=["i"], coords={"i": [1, 3]}) + result = expr.div(other, join="outer") + assert set(result.coords["i"].values) == {0, 1, 2, 3} + assert result.const.sel(i=1).item() == pytest.approx(5.0) + assert result.coeffs.squeeze().sel(i=1).item() == pytest.approx(0.5) + assert result.const.sel(i=0).item() == pytest.approx(10.0) + assert result.coeffs.squeeze().sel(i=0).item() == pytest.approx(1.0) + + class TestQuadratic: + def test_quadratic_add_constant_join_inner( + self, a: Variable, b: Variable + ) -> None: + quad = a.to_linexpr() * b.to_linexpr() + const = xr.DataArray([10, 20, 30], dims=["i"], coords={"i": [1, 2, 3]}) + result = quad.add(const, join="inner") + assert list(result.data.indexes["i"]) == [1, 2, 3] + + def test_quadratic_add_expr_join_inner(self, a: Variable) -> None: + quad = a.to_linexpr() * a.to_linexpr() + const = xr.DataArray([10, 20], dims=["i"], coords={"i": [0, 1]}) + result = quad.add(const, join="inner") + assert list(result.data.indexes["i"]) == [0, 1] + + def test_quadratic_mul_constant_join_inner( + self, a: Variable, b: Variable + ) -> None: + quad = a.to_linexpr() * b.to_linexpr() + const = xr.DataArray([2, 3, 4], dims=["i"], coords={"i": [1, 2, 3]}) + result = quad.mul(const, join="inner") + assert list(result.data.indexes["i"]) == [1, 2, 3] diff --git a/test/test_typing.py b/test/test_typing.py index 2375dc72..e6ba7ffb 100644 --- a/test/test_typing.py +++ b/test/test_typing.py @@ -1,8 +1,17 @@ +import pytest import xarray as xr import linopy +@pytest.fixture(autouse=True) +def _use_exact_join(): + """Use exact arithmetic join for all tests in this module.""" + linopy.options["arithmetic_join"] = "exact" + yield + linopy.options["arithmetic_join"] = "legacy" + + def test_operations_with_data_arrays_are_typed_correctly() -> None: m = linopy.Model() diff --git a/test/test_typing_legacy.py b/test/test_typing_legacy.py new file mode 100644 index 00000000..99a27033 --- /dev/null +++ b/test/test_typing_legacy.py @@ -0,0 +1,25 @@ +import xarray as xr + +import linopy + + +def test_operations_with_data_arrays_are_typed_correctly() -> None: + m = linopy.Model() + + a: xr.DataArray = xr.DataArray([1, 2, 3]) + + v: linopy.Variable = m.add_variables(lower=0.0, name="v") + e: linopy.LinearExpression = v * 1.0 + q = v * v + + _ = a * v + _ = v * a + _ = v + a + + _ = a * e + _ = e * a + _ = e + a + + _ = a * q + _ = q * a + _ = q + a From bf2f65825396dd321d90f6a894d1cf776d7a0762 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Tue, 10 Mar 2026 08:30:51 +0100 Subject: [PATCH 2/9] Simplify global setting to 'legacy'/'v1', add LinopyDeprecationWarning - Restrict options["arithmetic_join"] to {"legacy", "v1"} instead of exposing all xarray join values (explicit join= parameter still accepts any) - "v1" maps to "exact" join internally - Add LinopyDeprecationWarning class (subclass of FutureWarning) with centralized message including how to silence - Export LinopyDeprecationWarning from linopy.__init__ Co-Authored-By: Claude Opus 4.6 --- linopy/__init__.py | 3 ++- linopy/common.py | 11 +++++++---- linopy/config.py | 21 ++++++++++++--------- linopy/expressions.py | 28 +++++++++++++++------------- test/conftest.py | 2 +- test/test_algebraic_properties.py | 2 +- test/test_common.py | 2 +- test/test_constraints.py | 2 +- test/test_linear_expression.py | 2 +- test/test_typing.py | 2 +- 10 files changed, 42 insertions(+), 33 deletions(-) diff --git a/linopy/__init__.py b/linopy/__init__.py index 7f5acd46..d96f2d31 100644 --- a/linopy/__init__.py +++ b/linopy/__init__.py @@ -13,7 +13,7 @@ # we need to extend their __mul__ functions with a quick special case import linopy.monkey_patch_xarray # noqa: F401 from linopy.common import align -from linopy.config import options +from linopy.config import LinopyDeprecationWarning, options from linopy.constants import EQUAL, GREATER_EQUAL, LESS_EQUAL from linopy.constraints import Constraint, Constraints from linopy.expressions import LinearExpression, QuadraticExpression, merge @@ -29,6 +29,7 @@ "EQUAL", "GREATER_EQUAL", "LESS_EQUAL", + "LinopyDeprecationWarning", "LinearExpression", "Model", "Objective", diff --git a/linopy/common.py b/linopy/common.py index ea1b46d9..c820147f 100644 --- a/linopy/common.py +++ b/linopy/common.py @@ -1275,15 +1275,18 @@ def align( join = options["arithmetic_join"] if join == "legacy": + from linopy.config import LEGACY_DEPRECATION_MESSAGE, LinopyDeprecationWarning + warn( - "The 'legacy' arithmetic join is deprecated and will be removed " - "in a future version. Set linopy.options['arithmetic_join'] = " - "'exact' to opt in to the new behavior.", - FutureWarning, + LEGACY_DEPRECATION_MESSAGE, + LinopyDeprecationWarning, stacklevel=2, ) join = "inner" + elif join == "v1": + join = "exact" + # Extract underlying Datasets for index computation. das: list[Any] = [] for obj in objects: diff --git a/linopy/config.py b/linopy/config.py index c5637ce2..63143d46 100644 --- a/linopy/config.py +++ b/linopy/config.py @@ -9,15 +9,18 @@ from typing import Any -VALID_ARITHMETIC_JOINS = { - "exact", - "inner", - "outer", - "left", - "right", - "override", - "legacy", -} +VALID_ARITHMETIC_JOINS = {"legacy", "v1"} + +LEGACY_DEPRECATION_MESSAGE = ( + "The 'legacy' arithmetic join is deprecated and will be removed in a " + "future version. Set linopy.options['arithmetic_join'] = 'v1' to opt in " + "to the new behavior, or filter this warning with:\n" + " import warnings; warnings.filterwarnings('ignore', category=LinopyDeprecationWarning)" +) + + +class LinopyDeprecationWarning(FutureWarning): + """Warning for deprecated linopy features scheduled for removal.""" class OptionSettings: diff --git a/linopy/expressions.py b/linopy/expressions.py index bb08870b..7dda60a8 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -66,7 +66,7 @@ to_dataframe, to_polars, ) -from linopy.config import options +from linopy.config import LEGACY_DEPRECATION_MESSAGE, LinopyDeprecationWarning, options from linopy.constants import ( CV_DIM, EQUAL, @@ -554,10 +554,8 @@ def _align_constant( if join == "legacy": warn( - "The 'legacy' arithmetic join is deprecated and will be removed in a " - "future version. Set linopy.options['arithmetic_join'] = 'exact' to " - "opt in to the new behavior.", - FutureWarning, + LEGACY_DEPRECATION_MESSAGE, + LinopyDeprecationWarning, stacklevel=4, ) # Old behavior: override when same sizes, left join otherwise @@ -569,6 +567,9 @@ def _align_constant( False, ) + elif join == "v1": + join = "exact" + if join == "override": return self.const, other.assign_coords(coords=self.coords), False elif join == "left": @@ -1110,10 +1111,8 @@ def to_constraint( if effective_join == "legacy": warn( - "The 'legacy' arithmetic join is deprecated and will be removed " - "in a future version. Set linopy.options['arithmetic_join'] = " - "'exact' to opt in to the new behavior.", - FutureWarning, + LEGACY_DEPRECATION_MESSAGE, + LinopyDeprecationWarning, stacklevel=3, ) # Old behavior: convert to DataArray, warn about extra dims, @@ -1134,6 +1133,9 @@ def to_constraint( ) return constraints.Constraint(data, model=self.model) + if effective_join == "v1": + effective_join = "exact" + if isinstance(rhs, DataArray): if effective_join == "override": aligned_rhs = rhs.assign_coords(coords=self.const.coords) @@ -2450,10 +2452,8 @@ def merge( if effective_join == "legacy": warn( - "The 'legacy' arithmetic join is deprecated and will be removed " - "in a future version. Set linopy.options['arithmetic_join'] = " - "'exact' to opt in to the new behavior.", - FutureWarning, + LEGACY_DEPRECATION_MESSAGE, + LinopyDeprecationWarning, stacklevel=2, ) # Reproduce old behavior: override when all shared dims have @@ -2471,6 +2471,8 @@ def merge( override = False kwargs["join"] = "override" if override else "outer" + elif effective_join == "v1": + kwargs["join"] = "exact" else: kwargs["join"] = effective_join diff --git a/test/conftest.py b/test/conftest.py index 8a4343d6..02f1923d 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -57,7 +57,7 @@ def pytest_collection_modifyitems( @pytest.fixture def exact_join(): """Set arithmetic_join to 'exact' for the duration of a test.""" - linopy.options["arithmetic_join"] = "exact" + linopy.options["arithmetic_join"] = "v1" yield linopy.options["arithmetic_join"] = "legacy" diff --git a/test/test_algebraic_properties.py b/test/test_algebraic_properties.py index 74e9e8dd..b360b0cd 100644 --- a/test/test_algebraic_properties.py +++ b/test/test_algebraic_properties.py @@ -50,7 +50,7 @@ @pytest.fixture(autouse=True) def _use_exact_join(): """Use exact arithmetic join for all tests in this module.""" - linopy.options["arithmetic_join"] = "exact" + linopy.options["arithmetic_join"] = "v1" yield linopy.options["arithmetic_join"] = "legacy" diff --git a/test/test_common.py b/test/test_common.py index 72171211..c0389506 100644 --- a/test/test_common.py +++ b/test/test_common.py @@ -31,7 +31,7 @@ @pytest.fixture(autouse=True) def _use_exact_join(): """Use exact arithmetic join for all tests in this module.""" - linopy.options["arithmetic_join"] = "exact" + linopy.options["arithmetic_join"] = "v1" yield linopy.options["arithmetic_join"] = "legacy" diff --git a/test/test_constraints.py b/test/test_constraints.py index 55f92f6e..46bf9ad0 100644 --- a/test/test_constraints.py +++ b/test/test_constraints.py @@ -20,7 +20,7 @@ @pytest.fixture(autouse=True) def _use_exact_join(): """Use exact arithmetic join for all tests in this module.""" - linopy.options["arithmetic_join"] = "exact" + linopy.options["arithmetic_join"] = "v1" yield linopy.options["arithmetic_join"] = "legacy" diff --git a/test/test_linear_expression.py b/test/test_linear_expression.py index 81e5737d..44f4120e 100644 --- a/test/test_linear_expression.py +++ b/test/test_linear_expression.py @@ -25,7 +25,7 @@ @pytest.fixture(autouse=True) def _use_exact_join(): """Use exact arithmetic join for all tests in this module.""" - linopy.options["arithmetic_join"] = "exact" + linopy.options["arithmetic_join"] = "v1" yield linopy.options["arithmetic_join"] = "legacy" diff --git a/test/test_typing.py b/test/test_typing.py index e6ba7ffb..b8b760ae 100644 --- a/test/test_typing.py +++ b/test/test_typing.py @@ -7,7 +7,7 @@ @pytest.fixture(autouse=True) def _use_exact_join(): """Use exact arithmetic join for all tests in this module.""" - linopy.options["arithmetic_join"] = "exact" + linopy.options["arithmetic_join"] = "v1" yield linopy.options["arithmetic_join"] = "legacy" From 07003b15fd6ab5f048e614e9110e469c4e680e71 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Tue, 10 Mar 2026 08:36:05 +0100 Subject: [PATCH 3/9] Rename arithmetic_join to arithmetic_convention, mention v1 removal - Rename setting from 'arithmetic_join' to 'arithmetic_convention' - Update deprecation message: "will be removed in linopy v1" Co-Authored-By: Claude Opus 4.6 --- linopy/common.py | 2 +- linopy/config.py | 10 +++++----- linopy/expressions.py | 8 ++++---- test/conftest.py | 6 +++--- test/test_algebraic_properties.py | 4 ++-- test/test_common.py | 4 ++-- test/test_constraints.py | 4 ++-- test/test_linear_expression.py | 4 ++-- test/test_typing.py | 4 ++-- 9 files changed, 23 insertions(+), 23 deletions(-) diff --git a/linopy/common.py b/linopy/common.py index c820147f..09aad415 100644 --- a/linopy/common.py +++ b/linopy/common.py @@ -1272,7 +1272,7 @@ def align( from linopy.variables import Variable if join is None: - join = options["arithmetic_join"] + join = options["arithmetic_convention"] if join == "legacy": from linopy.config import LEGACY_DEPRECATION_MESSAGE, LinopyDeprecationWarning diff --git a/linopy/config.py b/linopy/config.py index 63143d46..9f04ce17 100644 --- a/linopy/config.py +++ b/linopy/config.py @@ -12,8 +12,8 @@ VALID_ARITHMETIC_JOINS = {"legacy", "v1"} LEGACY_DEPRECATION_MESSAGE = ( - "The 'legacy' arithmetic join is deprecated and will be removed in a " - "future version. Set linopy.options['arithmetic_join'] = 'v1' to opt in " + "The 'legacy' arithmetic convention is deprecated and will be removed in " + "linopy v1. Set linopy.options['arithmetic_convention'] = 'v1' to opt in " "to the new behavior, or filter this warning with:\n" " import warnings; warnings.filterwarnings('ignore', category=LinopyDeprecationWarning)" ) @@ -41,9 +41,9 @@ def set_value(self, **kwargs: Any) -> None: for k, v in kwargs.items(): if k not in self._defaults: raise KeyError(f"{k} is not a valid setting.") - if k == "arithmetic_join" and v not in VALID_ARITHMETIC_JOINS: + if k == "arithmetic_convention" and v not in VALID_ARITHMETIC_JOINS: raise ValueError( - f"Invalid arithmetic_join: {v!r}. " + f"Invalid arithmetic_convention: {v!r}. " f"Must be one of {VALID_ARITHMETIC_JOINS}." ) self._current_values[k] = v @@ -78,5 +78,5 @@ def __repr__(self) -> str: options = OptionSettings( display_max_rows=14, display_max_terms=6, - arithmetic_join="legacy", + arithmetic_convention="legacy", ) diff --git a/linopy/expressions.py b/linopy/expressions.py index 7dda60a8..61a28d55 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -538,7 +538,7 @@ def _align_constant( fill_value : float, default: 0 Fill value for missing coordinates. join : str, optional - Alignment method. If None, uses ``options["arithmetic_join"]``. + Alignment method. If None, uses ``options["arithmetic_convention"]``. Returns ------- @@ -550,7 +550,7 @@ def _align_constant( Whether the expression's data needs reindexing. """ if join is None: - join = options["arithmetic_join"] + join = options["arithmetic_convention"] if join == "legacy": warn( @@ -1107,7 +1107,7 @@ def to_constraint( f"Both sides of the constraint are constant. At least one side must contain variables. {self} {rhs}" ) - effective_join = join if join is not None else options["arithmetic_join"] + effective_join = join if join is not None else options["arithmetic_convention"] if effective_join == "legacy": warn( @@ -2448,7 +2448,7 @@ def merge( elif cls == variables.Variable: kwargs["fill_value"] = variables.FILL_VALUE - effective_join = join if join is not None else options["arithmetic_join"] + effective_join = join if join is not None else options["arithmetic_convention"] if effective_join == "legacy": warn( diff --git a/test/conftest.py b/test/conftest.py index 02f1923d..a0834307 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -56,10 +56,10 @@ def pytest_collection_modifyitems( @pytest.fixture def exact_join(): - """Set arithmetic_join to 'exact' for the duration of a test.""" - linopy.options["arithmetic_join"] = "v1" + """Set arithmetic_convention to 'exact' for the duration of a test.""" + linopy.options["arithmetic_convention"] = "v1" yield - linopy.options["arithmetic_join"] = "legacy" + linopy.options["arithmetic_convention"] = "legacy" @pytest.fixture diff --git a/test/test_algebraic_properties.py b/test/test_algebraic_properties.py index b360b0cd..c763e2d8 100644 --- a/test/test_algebraic_properties.py +++ b/test/test_algebraic_properties.py @@ -50,9 +50,9 @@ @pytest.fixture(autouse=True) def _use_exact_join(): """Use exact arithmetic join for all tests in this module.""" - linopy.options["arithmetic_join"] = "v1" + linopy.options["arithmetic_convention"] = "v1" yield - linopy.options["arithmetic_join"] = "legacy" + linopy.options["arithmetic_convention"] = "legacy" @pytest.fixture diff --git a/test/test_common.py b/test/test_common.py index c0389506..75f556dd 100644 --- a/test/test_common.py +++ b/test/test_common.py @@ -31,9 +31,9 @@ @pytest.fixture(autouse=True) def _use_exact_join(): """Use exact arithmetic join for all tests in this module.""" - linopy.options["arithmetic_join"] = "v1" + linopy.options["arithmetic_convention"] = "v1" yield - linopy.options["arithmetic_join"] = "legacy" + linopy.options["arithmetic_convention"] = "legacy" # Fixtures m, u, x are provided by conftest.py diff --git a/test/test_constraints.py b/test/test_constraints.py index 46bf9ad0..39fa0f6d 100644 --- a/test/test_constraints.py +++ b/test/test_constraints.py @@ -20,9 +20,9 @@ @pytest.fixture(autouse=True) def _use_exact_join(): """Use exact arithmetic join for all tests in this module.""" - linopy.options["arithmetic_join"] = "v1" + linopy.options["arithmetic_convention"] = "v1" yield - linopy.options["arithmetic_join"] = "legacy" + linopy.options["arithmetic_convention"] = "legacy" # Test model functions diff --git a/test/test_linear_expression.py b/test/test_linear_expression.py index 44f4120e..41bdea74 100644 --- a/test/test_linear_expression.py +++ b/test/test_linear_expression.py @@ -25,9 +25,9 @@ @pytest.fixture(autouse=True) def _use_exact_join(): """Use exact arithmetic join for all tests in this module.""" - linopy.options["arithmetic_join"] = "v1" + linopy.options["arithmetic_convention"] = "v1" yield - linopy.options["arithmetic_join"] = "legacy" + linopy.options["arithmetic_convention"] = "legacy" # Fixtures m, x, y, z, v, u are provided by conftest.py diff --git a/test/test_typing.py b/test/test_typing.py index b8b760ae..fbabfa12 100644 --- a/test/test_typing.py +++ b/test/test_typing.py @@ -7,9 +7,9 @@ @pytest.fixture(autouse=True) def _use_exact_join(): """Use exact arithmetic join for all tests in this module.""" - linopy.options["arithmetic_join"] = "v1" + linopy.options["arithmetic_convention"] = "v1" yield - linopy.options["arithmetic_join"] = "legacy" + linopy.options["arithmetic_convention"] = "legacy" def test_operations_with_data_arrays_are_typed_correctly() -> None: From cc87789645c1371d066d35326eda8218c6e233ca Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Tue, 10 Mar 2026 09:28:06 +0100 Subject: [PATCH 4/9] Rename test fixtures from exact_join to v1_convention Co-Authored-By: Claude Opus 4.6 --- test/conftest.py | 4 ++-- test/test_algebraic_properties.py | 4 ++-- test/test_common.py | 4 ++-- test/test_constraints.py | 4 ++-- test/test_linear_expression.py | 4 ++-- test/test_typing.py | 4 ++-- 6 files changed, 12 insertions(+), 12 deletions(-) diff --git a/test/conftest.py b/test/conftest.py index a0834307..623f4934 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -55,8 +55,8 @@ def pytest_collection_modifyitems( @pytest.fixture -def exact_join(): - """Set arithmetic_convention to 'exact' for the duration of a test.""" +def v1_convention(): + """Set arithmetic_convention to 'v1' for the duration of a test.""" linopy.options["arithmetic_convention"] = "v1" yield linopy.options["arithmetic_convention"] = "legacy" diff --git a/test/test_algebraic_properties.py b/test/test_algebraic_properties.py index c763e2d8..1e8d5db8 100644 --- a/test/test_algebraic_properties.py +++ b/test/test_algebraic_properties.py @@ -48,8 +48,8 @@ @pytest.fixture(autouse=True) -def _use_exact_join(): - """Use exact arithmetic join for all tests in this module.""" +def _use_v1_convention(): + """Use v1 arithmetic convention for all tests in this module.""" linopy.options["arithmetic_convention"] = "v1" yield linopy.options["arithmetic_convention"] = "legacy" diff --git a/test/test_common.py b/test/test_common.py index 75f556dd..420b9bd7 100644 --- a/test/test_common.py +++ b/test/test_common.py @@ -29,8 +29,8 @@ @pytest.fixture(autouse=True) -def _use_exact_join(): - """Use exact arithmetic join for all tests in this module.""" +def _use_v1_convention(): + """Use v1 arithmetic convention for all tests in this module.""" linopy.options["arithmetic_convention"] = "v1" yield linopy.options["arithmetic_convention"] = "legacy" diff --git a/test/test_constraints.py b/test/test_constraints.py index 39fa0f6d..ff40391b 100644 --- a/test/test_constraints.py +++ b/test/test_constraints.py @@ -18,8 +18,8 @@ @pytest.fixture(autouse=True) -def _use_exact_join(): - """Use exact arithmetic join for all tests in this module.""" +def _use_v1_convention(): + """Use v1 arithmetic convention for all tests in this module.""" linopy.options["arithmetic_convention"] = "v1" yield linopy.options["arithmetic_convention"] = "legacy" diff --git a/test/test_linear_expression.py b/test/test_linear_expression.py index 41bdea74..9c8c58e6 100644 --- a/test/test_linear_expression.py +++ b/test/test_linear_expression.py @@ -23,8 +23,8 @@ @pytest.fixture(autouse=True) -def _use_exact_join(): - """Use exact arithmetic join for all tests in this module.""" +def _use_v1_convention(): + """Use v1 arithmetic convention for all tests in this module.""" linopy.options["arithmetic_convention"] = "v1" yield linopy.options["arithmetic_convention"] = "legacy" diff --git a/test/test_typing.py b/test/test_typing.py index fbabfa12..6e7d75f8 100644 --- a/test/test_typing.py +++ b/test/test_typing.py @@ -5,8 +5,8 @@ @pytest.fixture(autouse=True) -def _use_exact_join(): - """Use exact arithmetic join for all tests in this module.""" +def _use_v1_convention(): + """Use v1 arithmetic convention for all tests in this module.""" linopy.options["arithmetic_convention"] = "v1" yield linopy.options["arithmetic_convention"] = "legacy" From 4382b12da9805be88131d3a80e4e2232ddac6342 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Tue, 10 Mar 2026 13:04:37 +0100 Subject: [PATCH 5/9] Update legacy tests --- test/test_common_legacy.py | 1 - test/test_linear_expression_legacy.py | 130 +++++++++++++++++++------- 2 files changed, 94 insertions(+), 37 deletions(-) diff --git a/test/test_common_legacy.py b/test/test_common_legacy.py index 7e623bf6..f1190024 100644 --- a/test/test_common_legacy.py +++ b/test/test_common_legacy.py @@ -649,7 +649,6 @@ def test_get_dims_with_index_levels() -> None: assert get_dims_with_index_levels(ds5) == [] -@pytest.mark.xfail(reason="xarray MultiIndex alignment incompatibility") def test_align(x: Variable, u: Variable) -> None: # noqa: F811 alpha = xr.DataArray([1, 2], [[1, 2]]) beta = xr.DataArray( diff --git a/test/test_linear_expression_legacy.py b/test/test_linear_expression_legacy.py index 2cfb315b..d3b8d426 100644 --- a/test/test_linear_expression_legacy.py +++ b/test/test_linear_expression_legacy.py @@ -1,4 +1,3 @@ -# ruff: noqa: D106 #!/usr/bin/env python3 """ Created on Wed Mar 17 17:06:36 2021. @@ -807,41 +806,51 @@ def test_subset_add_quadexpr(self, v: Variable, subset: xr.DataArray) -> None: assert_quadequal(subset + qexpr, qexpr + subset) class TestMissingValues: - """Same shape as variable but with NaN entries in the constant.""" + """ + Same shape as variable but with NaN entries in the constant. - EXPECTED_NAN_MASK = np.zeros(20, dtype=bool) - EXPECTED_NAN_MASK[[0, 5, 19]] = True + NaN values are filled with operation-specific neutral elements: + - Addition/subtraction: NaN -> 0 (additive identity) + - Multiplication: NaN -> 0 (zeroes out the variable) + - Division: NaN -> 1 (multiplicative identity, no scaling) + """ + + NAN_POSITIONS = [0, 5, 19] @pytest.mark.parametrize("operand", ["var", "expr"]) - def test_add_nan_propagates( + def test_add_nan_filled( self, v: Variable, nan_constant: xr.DataArray | pd.Series, operand: str, ) -> None: + base_const = 0.0 if operand == "var" else 5.0 target = v if operand == "var" else v + 5 result = target + nan_constant assert result.sizes["dim_2"] == 20 - np.testing.assert_array_equal( - np.isnan(result.const.values), self.EXPECTED_NAN_MASK - ) + assert not np.isnan(result.const.values).any() + # At NaN positions, const should be unchanged (added 0) + for i in self.NAN_POSITIONS: + assert result.const.values[i] == base_const @pytest.mark.parametrize("operand", ["var", "expr"]) - def test_sub_nan_propagates( + def test_sub_nan_filled( self, v: Variable, nan_constant: xr.DataArray | pd.Series, operand: str, ) -> None: + base_const = 0.0 if operand == "var" else 5.0 target = v if operand == "var" else v + 5 result = target - nan_constant assert result.sizes["dim_2"] == 20 - np.testing.assert_array_equal( - np.isnan(result.const.values), self.EXPECTED_NAN_MASK - ) + assert not np.isnan(result.const.values).any() + # At NaN positions, const should be unchanged (subtracted 0) + for i in self.NAN_POSITIONS: + assert result.const.values[i] == base_const @pytest.mark.parametrize("operand", ["var", "expr"]) - def test_mul_nan_propagates( + def test_mul_nan_filled( self, v: Variable, nan_constant: xr.DataArray | pd.Series, @@ -850,12 +859,13 @@ def test_mul_nan_propagates( target = v if operand == "var" else 1 * v result = target * nan_constant assert result.sizes["dim_2"] == 20 - np.testing.assert_array_equal( - np.isnan(result.coeffs.squeeze().values), self.EXPECTED_NAN_MASK - ) + assert not np.isnan(result.coeffs.squeeze().values).any() + # At NaN positions, coeffs should be 0 (variable zeroed out) + for i in self.NAN_POSITIONS: + assert result.coeffs.squeeze().values[i] == 0.0 @pytest.mark.parametrize("operand", ["var", "expr"]) - def test_div_nan_propagates( + def test_div_nan_filled( self, v: Variable, nan_constant: xr.DataArray | pd.Series, @@ -864,9 +874,11 @@ def test_div_nan_propagates( target = v if operand == "var" else 1 * v result = target / nan_constant assert result.sizes["dim_2"] == 20 - np.testing.assert_array_equal( - np.isnan(result.coeffs.squeeze().values), self.EXPECTED_NAN_MASK - ) + assert not np.isnan(result.coeffs.squeeze().values).any() + # At NaN positions, coeffs should be unchanged (divided by 1) + original_coeffs = (1 * v).coeffs.squeeze().values + for i in self.NAN_POSITIONS: + assert result.coeffs.squeeze().values[i] == original_coeffs[i] def test_add_commutativity( self, @@ -875,14 +887,9 @@ def test_add_commutativity( ) -> None: result_a = v + nan_constant result_b = nan_constant + v - # Compare non-NaN values are equal and NaN positions match - nan_mask_a = np.isnan(result_a.const.values) - nan_mask_b = np.isnan(result_b.const.values) - np.testing.assert_array_equal(nan_mask_a, nan_mask_b) - np.testing.assert_array_equal( - result_a.const.values[~nan_mask_a], - result_b.const.values[~nan_mask_b], - ) + assert not np.isnan(result_a.const.values).any() + assert not np.isnan(result_b.const.values).any() + np.testing.assert_array_equal(result_a.const.values, result_b.const.values) np.testing.assert_array_equal( result_a.coeffs.values, result_b.coeffs.values ) @@ -894,12 +901,10 @@ def test_mul_commutativity( ) -> None: result_a = v * nan_constant result_b = nan_constant * v - nan_mask_a = np.isnan(result_a.coeffs.values) - nan_mask_b = np.isnan(result_b.coeffs.values) - np.testing.assert_array_equal(nan_mask_a, nan_mask_b) + assert not np.isnan(result_a.coeffs.values).any() + assert not np.isnan(result_b.coeffs.values).any() np.testing.assert_array_equal( - result_a.coeffs.values[~nan_mask_a], - result_b.coeffs.values[~nan_mask_b], + result_a.coeffs.values, result_b.coeffs.values ) def test_quadexpr_add_nan( @@ -911,9 +916,62 @@ def test_quadexpr_add_nan( result = qexpr + nan_constant assert isinstance(result, QuadraticExpression) assert result.sizes["dim_2"] == 20 - np.testing.assert_array_equal( - np.isnan(result.const.values), self.EXPECTED_NAN_MASK - ) + assert not np.isnan(result.const.values).any() + + class TestExpressionWithNaN: + """Test that NaN in expression's own const/coeffs doesn't propagate.""" + + def test_shifted_expr_add_scalar(self, v: Variable) -> None: + expr = (1 * v).shift(dim_2=1) + result = expr + 5 + assert not np.isnan(result.const.values).any() + assert result.const.values[0] == 5.0 + + def test_shifted_expr_mul_scalar(self, v: Variable) -> None: + expr = (1 * v).shift(dim_2=1) + result = expr * 2 + assert not np.isnan(result.coeffs.squeeze().values).any() + assert result.coeffs.squeeze().values[0] == 0.0 + + def test_shifted_expr_add_array(self, v: Variable) -> None: + arr = np.arange(v.sizes["dim_2"], dtype=float) + expr = (1 * v).shift(dim_2=1) + result = expr + arr + assert not np.isnan(result.const.values).any() + assert result.const.values[0] == 0.0 + + def test_shifted_expr_mul_array(self, v: Variable) -> None: + arr = np.arange(v.sizes["dim_2"], dtype=float) + 1 + expr = (1 * v).shift(dim_2=1) + result = expr * arr + assert not np.isnan(result.coeffs.squeeze().values).any() + assert result.coeffs.squeeze().values[0] == 0.0 + + def test_shifted_expr_div_scalar(self, v: Variable) -> None: + expr = (1 * v).shift(dim_2=1) + result = expr / 2 + assert not np.isnan(result.coeffs.squeeze().values).any() + assert result.coeffs.squeeze().values[0] == 0.0 + + def test_shifted_expr_sub_scalar(self, v: Variable) -> None: + expr = (1 * v).shift(dim_2=1) + result = expr - 3 + assert not np.isnan(result.const.values).any() + assert result.const.values[0] == -3.0 + + def test_shifted_expr_div_array(self, v: Variable) -> None: + arr = np.arange(v.sizes["dim_2"], dtype=float) + 1 + expr = (1 * v).shift(dim_2=1) + result = expr / arr + assert not np.isnan(result.coeffs.squeeze().values).any() + assert result.coeffs.squeeze().values[0] == 0.0 + + def test_variable_to_linexpr_nan_coefficient(self, v: Variable) -> None: + nan_coeff = np.ones(v.sizes["dim_2"]) + nan_coeff[0] = np.nan + result = v.to_linexpr(nan_coeff) + assert not np.isnan(result.coeffs.squeeze().values).any() + assert result.coeffs.squeeze().values[0] == 0.0 class TestMultiDim: def test_multidim_subset_mul(self, m: Model) -> None: From 465637d12c3d41def02d2778b5bd738d3848bf27 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Tue, 10 Mar 2026 14:38:11 +0100 Subject: [PATCH 6/9] Merge harmonize-linopy-operations-mixed, restore NaN filling and align function - Resolve merge conflicts keeping transition layer logic - Restore NaN fillna(0) in _add_constant and _apply_constant_op - Restore simple finisher-based align() function (fixes MultiIndex) - Use check_common_keys_values in merge legacy path - Update legacy test files to match origin/harmonize-linopy-operations Co-Authored-By: Claude Opus 4.6 --- linopy/common.py | 48 +++++++++++++++++++++---------------------- linopy/expressions.py | 13 +++++++++--- 2 files changed, 33 insertions(+), 28 deletions(-) diff --git a/linopy/common.py b/linopy/common.py index 71890841..4bcf553a 100644 --- a/linopy/common.py +++ b/linopy/common.py @@ -10,7 +10,7 @@ import operator import os from collections.abc import Callable, Generator, Hashable, Iterable, Sequence -from functools import reduce, wraps +from functools import partial, reduce, wraps from pathlib import Path from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload from warnings import warn @@ -1267,38 +1267,36 @@ def align( elif join == "v1": join = "exact" - # Extract underlying Datasets for index computation. + finisher: list[partial[Any] | Callable[[Any], Any]] = [] das: list[Any] = [] for obj in objects: - if isinstance(obj, LinearExpression | QuadraticExpression | Variable): + if isinstance(obj, LinearExpression | QuadraticExpression): + finisher.append(partial(obj.__class__, model=obj.model)) + das.append(obj.data) + elif isinstance(obj, Variable): + finisher.append( + partial( + obj.__class__, + model=obj.model, + name=obj.data.attrs["name"], + skip_broadcast=True, + ) + ) das.append(obj.data) else: + finisher.append(lambda x: x) das.append(obj) exclude = frozenset(exclude).union(HELPER_DIMS) - - # Compute target indexes. - target_aligned = xr_align( - *das, join=join, copy=False, indexes=indexes, exclude=exclude + aligned = xr_align( + *das, + join=join, + copy=copy, + indexes=indexes, + exclude=exclude, + fill_value=fill_value, ) - - # Reindex each object to target indexes. - reindex_kwargs: dict[str, Any] = {} - if fill_value is not dtypes.NA: - reindex_kwargs["fill_value"] = fill_value - results: list[Any] = [] - for obj, target in zip(objects, target_aligned): - indexers = { - dim: target.indexes[dim] - for dim in target.dims - if dim not in exclude and dim in target.indexes - } - # Variable.reindex has no fill_value — it always uses sentinels - if isinstance(obj, Variable): - results.append(obj.reindex(indexers)) - else: - results.append(obj.reindex(indexers, **reindex_kwargs)) # type: ignore[union-attr] - return tuple(results) + return tuple([f(da) for f, da in zip(finisher, aligned)]) LocT = TypeVar( diff --git a/linopy/expressions.py b/linopy/expressions.py index 651cbeb2..418b8d16 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -624,11 +624,13 @@ def _add_constant( self: GenericExpression, other: ConstantLike, join: JoinOptions | None = None ) -> GenericExpression: if np.isscalar(other) and join is None: - return self.assign(const=self.const + other) + return self.assign(const=self.const.fillna(0) + other) da = as_dataarray(other, coords=self.coords, dims=self.coord_dims) self_const, da, needs_data_reindex = self._align_constant( da, fill_value=0, join=join ) + da = da.fillna(0) + self_const = self_const.fillna(0) if needs_data_reindex: fv = {**self._fill_value, "const": 0} return self.__class__( @@ -650,16 +652,21 @@ def _apply_constant_op( self_const, factor, needs_data_reindex = self._align_constant( factor, fill_value=fill_value, join=join ) + factor = factor.fillna(fill_value) + self_const = self_const.fillna(0) if needs_data_reindex: fv = {**self._fill_value, "const": 0} data = self.data.reindex_like(self_const, fill_value=fv) return self.__class__( assign_multiindex_safe( - data, coeffs=op(data.coeffs, factor), const=op(self_const, factor) + data, + coeffs=op(data.coeffs.fillna(0), factor), + const=op(self_const, factor), ), self.model, ) - return self.assign(coeffs=op(self.coeffs, factor), const=op(self_const, factor)) + coeffs = self.coeffs.fillna(0) + return self.assign(coeffs=op(coeffs, factor), const=op(self_const, factor)) def _multiply_by_constant( self: GenericExpression, other: ConstantLike, join: JoinOptions | None = None From e422a5e097d340c72cff6b786a6db42fcd8ae7f1 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Tue, 10 Mar 2026 15:24:44 +0100 Subject: [PATCH 7/9] Clean up obsolete code and fix convention-awareness in arithmetic - Remove dead check_common_keys_values function from common.py - Remove redundant default_join parameter from _align_constant, use options["arithmetic_convention"] directly - Gate fillna(0) calls in _add_constant and _apply_constant_op behind legacy convention check so NaN values propagate correctly under v1 - Fix legacy to_constraint path to compute constraint RHS directly instead of routing through sub() which re-applies fillna - Restore Variable.__mul__ scalar fast path via to_linexpr(other) - Restore Variable.__div__ explicit TypeError for non-linear division - Update v1 tests to expect ValueError on mismatched coords and test explicit join= escape hatches Co-Authored-By: Claude Opus 4.6 --- linopy/expressions.py | 36 ++-- linopy/variables.py | 15 +- test/test_constraints.py | 82 +++++---- test/test_linear_expression.py | 305 +++++++++++++++++++++------------ 4 files changed, 273 insertions(+), 165 deletions(-) diff --git a/linopy/expressions.py b/linopy/expressions.py index 418b8d16..c0600852 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -623,14 +623,19 @@ def _align_constant( def _add_constant( self: GenericExpression, other: ConstantLike, join: JoinOptions | None = None ) -> GenericExpression: + is_legacy = ( + join is None and options["arithmetic_convention"] == "legacy" + ) or join == "legacy" if np.isscalar(other) and join is None: - return self.assign(const=self.const.fillna(0) + other) + const = self.const.fillna(0) + other if is_legacy else self.const + other + return self.assign(const=const) da = as_dataarray(other, coords=self.coords, dims=self.coord_dims) self_const, da, needs_data_reindex = self._align_constant( da, fill_value=0, join=join ) - da = da.fillna(0) - self_const = self_const.fillna(0) + if is_legacy: + da = da.fillna(0) + self_const = self_const.fillna(0) if needs_data_reindex: fv = {**self._fill_value, "const": 0} return self.__class__( @@ -648,24 +653,29 @@ def _apply_constant_op( fill_value: float, join: JoinOptions | None = None, ) -> GenericExpression: + is_legacy = ( + join is None and options["arithmetic_convention"] == "legacy" + ) or join == "legacy" factor = as_dataarray(other, coords=self.coords, dims=self.coord_dims) self_const, factor, needs_data_reindex = self._align_constant( factor, fill_value=fill_value, join=join ) - factor = factor.fillna(fill_value) - self_const = self_const.fillna(0) + if is_legacy: + factor = factor.fillna(fill_value) + self_const = self_const.fillna(0) if needs_data_reindex: fv = {**self._fill_value, "const": 0} data = self.data.reindex_like(self_const, fill_value=fv) + coeffs = data.coeffs.fillna(0) if is_legacy else data.coeffs return self.__class__( assign_multiindex_safe( data, - coeffs=op(data.coeffs.fillna(0), factor), + coeffs=op(coeffs, factor), const=op(self_const, factor), ), self.model, ) - coeffs = self.coeffs.fillna(0) + coeffs = self.coeffs.fillna(0) if is_legacy else self.coeffs return self.assign(coeffs=op(coeffs, factor), const=op(self_const, factor)) def _multiply_by_constant( @@ -1185,11 +1195,13 @@ def to_constraint( f"Consider collapsing the dimensions by taking min/max." ) rhs = rhs.reindex_like(self.const, fill_value=np.nan) - all_to_lhs = self.sub(rhs, join=join).data - data = assign_multiindex_safe( - all_to_lhs[["coeffs", "vars"]], sign=sign, rhs=-all_to_lhs.const - ) - return constraints.Constraint(data, model=self.model) + # Alignment already done — compute constraint directly + constraint_rhs = rhs - self.const + data = assign_multiindex_safe( + self.data[["coeffs", "vars"]], sign=sign, rhs=constraint_rhs + ) + return constraints.Constraint(data, model=self.model) + # Non-constant rhs (Variable/Expression) — fall through to sub path if effective_join == "v1": effective_join = "exact" diff --git a/linopy/variables.py b/linopy/variables.py index d80a14bd..991df0da 100644 --- a/linopy/variables.py +++ b/linopy/variables.py @@ -400,6 +400,13 @@ def __mul__(self, other: SideLike) -> ExpressionLike: Multiply variables with a coefficient, variable, or expression. """ try: + if isinstance(other, Variable | ScalarVariable): + return self.to_linexpr() * other + + # Fast path for scalars: build expression directly with coefficient + if np.isscalar(other): + return self.to_linexpr(other) + return self.to_linexpr() * other except TypeError: return NotImplemented @@ -448,7 +455,13 @@ def __div__( """ Divide variables with a coefficient. """ - return self.to_linexpr() / other + if isinstance(other, expressions.LinearExpression | Variable): + raise TypeError( + "unsupported operand type(s) for /: " + f"{type(self)} and {type(other)}. " + "Non-linear expressions are not yet supported." + ) + return self.to_linexpr()._divide_by_constant(other) def __truediv__( self, coefficient: ConstantLike | LinearExpression | Variable diff --git a/test/test_constraints.py b/test/test_constraints.py index 7128af75..4c111e14 100644 --- a/test/test_constraints.py +++ b/test/test_constraints.py @@ -357,67 +357,72 @@ def superset(self, request: Any) -> xr.DataArray | pd.Series: np.arange(25, dtype=float), index=pd.Index(range(25), name="dim_2") ) - def test_var_le_subset(self, v: Variable, subset: xr.DataArray) -> None: - con = v <= subset + def test_var_le_subset_raises(self, v: Variable, subset: xr.DataArray) -> None: + with pytest.raises(ValueError, match="exact"): + v <= subset + + def test_var_le_subset_join_left(self, v: Variable) -> None: + subset_da = xr.DataArray([10.0, 30.0], dims=["dim_2"], coords={"dim_2": [1, 3]}) + con = v.to_linexpr().le(subset_da, join="left") assert con.sizes["dim_2"] == v.sizes["dim_2"] assert con.rhs.sel(dim_2=1).item() == 10.0 assert con.rhs.sel(dim_2=3).item() == 30.0 assert np.isnan(con.rhs.sel(dim_2=0).item()) @pytest.mark.parametrize("sign", [LESS_EQUAL, GREATER_EQUAL, EQUAL]) - def test_var_comparison_subset( + def test_var_comparison_subset_raises( self, v: Variable, subset: xr.DataArray, sign: str ) -> None: - if sign == LESS_EQUAL: - con = v <= subset - elif sign == GREATER_EQUAL: - con = v >= subset - else: - con = v == subset - assert con.sizes["dim_2"] == v.sizes["dim_2"] - assert con.rhs.sel(dim_2=1).item() == 10.0 - assert np.isnan(con.rhs.sel(dim_2=0).item()) + with pytest.raises(ValueError, match="exact"): + if sign == LESS_EQUAL: + v <= subset + elif sign == GREATER_EQUAL: + v >= subset + else: + v == subset + + def test_expr_le_subset_raises(self, v: Variable, subset: xr.DataArray) -> None: + expr = v + 5 + with pytest.raises(ValueError, match="exact"): + expr <= subset - def test_expr_le_subset(self, v: Variable, subset: xr.DataArray) -> None: + def test_expr_le_subset_join_left(self, v: Variable) -> None: + subset_da = xr.DataArray([10.0, 30.0], dims=["dim_2"], coords={"dim_2": [1, 3]}) expr = v + 5 - con = expr <= subset + con = expr.le(subset_da, join="left") assert con.sizes["dim_2"] == v.sizes["dim_2"] assert con.rhs.sel(dim_2=1).item() == pytest.approx(5.0) assert con.rhs.sel(dim_2=3).item() == pytest.approx(25.0) assert np.isnan(con.rhs.sel(dim_2=0).item()) - @pytest.mark.parametrize("sign", [LESS_EQUAL, GREATER_EQUAL, EQUAL]) - def test_subset_comparison_var( - self, v: Variable, subset: xr.DataArray, sign: str + def test_subset_comparison_var_raises( + self, v: Variable, subset: xr.DataArray ) -> None: - if sign == LESS_EQUAL: - con = subset <= v - elif sign == GREATER_EQUAL: - con = subset >= v - else: - con = subset == v - assert con.sizes["dim_2"] == v.sizes["dim_2"] - assert np.isnan(con.rhs.sel(dim_2=0).item()) - assert con.rhs.sel(dim_2=1).item() == pytest.approx(10.0) + with pytest.raises(ValueError, match="exact"): + subset <= v - @pytest.mark.parametrize("sign", [LESS_EQUAL, GREATER_EQUAL]) - def test_superset_comparison_var( - self, v: Variable, superset: xr.DataArray, sign: str + def test_superset_comparison_var_raises( + self, v: Variable, superset: xr.DataArray ) -> None: - if sign == LESS_EQUAL: - con = superset <= v - else: - con = superset >= v - assert con.sizes["dim_2"] == v.sizes["dim_2"] - assert not np.isnan(con.lhs.coeffs.values).any() - assert not np.isnan(con.rhs.values).any() + with pytest.raises(ValueError, match="exact"): + superset <= v - def test_constraint_rhs_extra_dims_broadcasts(self, v: Variable) -> None: + def test_constraint_rhs_extra_dims_raises_on_mismatch(self, v: Variable) -> None: rhs = xr.DataArray( [[1.0, 2.0]], dims=["extra", "dim_2"], coords={"dim_2": [0, 1]}, ) + # dim_2 coords [0,1] don't match v's [0..19] under exact join + with pytest.raises(ValueError, match="exact"): + v <= rhs + + def test_constraint_rhs_extra_dims_broadcasts_matching(self, v: Variable) -> None: + rhs = xr.DataArray( + np.ones((2, 20)), + dims=["extra", "dim_2"], + coords={"dim_2": range(20)}, + ) c = v <= rhs assert "extra" in c.dims @@ -429,7 +434,8 @@ def test_subset_constraint_solve_integration(self) -> None: coords = pd.RangeIndex(5, name="i") x = m.add_variables(lower=0, upper=100, coords=[coords], name="x") subset_ub = xr.DataArray([10.0, 20.0], dims=["i"], coords={"i": [1, 3]}) - m.add_constraints(x <= subset_ub, name="subset_ub") + # exact default raises — use explicit join="left" (NaN = no constraint) + m.add_constraints(x.to_linexpr().le(subset_ub, join="left"), name="subset_ub") m.add_objective(x.sum(), sense="max") m.solve(solver_name=solver) sol = m.solution["x"] diff --git a/test/test_linear_expression.py b/test/test_linear_expression.py index 3910ff7a..fe7bc651 100644 --- a/test/test_linear_expression.py +++ b/test/test_linear_expression.py @@ -601,8 +601,24 @@ def nan_constant(self, request: Any) -> xr.DataArray | pd.Series: return pd.Series(vals, index=pd.Index(range(20), name="dim_2")) class TestSubset: + """ + Under v1, subset operations raise ValueError (exact join). + Use explicit join= to recover desired behavior. + """ + @pytest.mark.parametrize("operand", ["var", "expr"]) - def test_mul_subset_fills_zeros( + def test_mul_subset_raises( + self, + v: Variable, + subset: xr.DataArray, + operand: str, + ) -> None: + target = v if operand == "var" else 1 * v + with pytest.raises(ValueError, match="exact"): + target * subset + + @pytest.mark.parametrize("operand", ["var", "expr"]) + def test_mul_subset_join_left( self, v: Variable, subset: xr.DataArray, @@ -610,13 +626,24 @@ def test_mul_subset_fills_zeros( operand: str, ) -> None: target = v if operand == "var" else 1 * v - result = target * subset + result = target.mul(subset, join="left") assert result.sizes["dim_2"] == v.sizes["dim_2"] assert not np.isnan(result.coeffs.values).any() np.testing.assert_array_equal(result.coeffs.squeeze().values, expected_fill) @pytest.mark.parametrize("operand", ["var", "expr"]) - def test_add_subset_fills_zeros( + def test_add_subset_raises( + self, + v: Variable, + subset: xr.DataArray, + operand: str, + ) -> None: + target = v if operand == "var" else v + 5 + with pytest.raises(ValueError, match="exact"): + target + subset + + @pytest.mark.parametrize("operand", ["var", "expr"]) + def test_add_subset_join_left( self, v: Variable, subset: xr.DataArray, @@ -624,17 +651,28 @@ def test_add_subset_fills_zeros( operand: str, ) -> None: if operand == "var": - result = v + subset + result = v.add(subset, join="left") expected = expected_fill else: - result = (v + 5) + subset + result = (v + 5).add(subset, join="left") expected = expected_fill + 5 assert result.sizes["dim_2"] == v.sizes["dim_2"] assert not np.isnan(result.const.values).any() np.testing.assert_array_equal(result.const.values, expected) @pytest.mark.parametrize("operand", ["var", "expr"]) - def test_sub_subset_fills_negated( + def test_sub_subset_raises( + self, + v: Variable, + subset: xr.DataArray, + operand: str, + ) -> None: + target = v if operand == "var" else v + 5 + with pytest.raises(ValueError, match="exact"): + target - subset + + @pytest.mark.parametrize("operand", ["var", "expr"]) + def test_sub_subset_join_left( self, v: Variable, subset: xr.DataArray, @@ -642,195 +680,225 @@ def test_sub_subset_fills_negated( operand: str, ) -> None: if operand == "var": - result = v - subset + result = v.sub(subset, join="left") expected = -expected_fill else: - result = (v + 5) - subset + result = (v + 5).sub(subset, join="left") expected = 5 - expected_fill assert result.sizes["dim_2"] == v.sizes["dim_2"] assert not np.isnan(result.const.values).any() np.testing.assert_array_equal(result.const.values, expected) @pytest.mark.parametrize("operand", ["var", "expr"]) - def test_div_subset_inverts_nonzero( + def test_div_subset_raises( + self, v: Variable, subset: xr.DataArray, operand: str + ) -> None: + target = v if operand == "var" else 1 * v + with pytest.raises(ValueError, match="exact"): + target / subset + + @pytest.mark.parametrize("operand", ["var", "expr"]) + def test_div_subset_join_left( self, v: Variable, subset: xr.DataArray, operand: str ) -> None: target = v if operand == "var" else 1 * v - result = target / subset + result = target.div(subset, join="left") assert result.sizes["dim_2"] == v.sizes["dim_2"] assert not np.isnan(result.coeffs.values).any() assert result.coeffs.squeeze().sel(dim_2=1).item() == pytest.approx(0.1) assert result.coeffs.squeeze().sel(dim_2=0).item() == pytest.approx(1.0) - def test_subset_add_var_coefficients( - self, v: Variable, subset: xr.DataArray - ) -> None: - result = subset + v - np.testing.assert_array_equal(result.coeffs.squeeze().values, np.ones(20)) + def test_subset_add_var_raises(self, v: Variable, subset: xr.DataArray) -> None: + with pytest.raises(ValueError, match="exact"): + subset + v - def test_subset_sub_var_coefficients( - self, v: Variable, subset: xr.DataArray - ) -> None: - result = subset - v - np.testing.assert_array_equal(result.coeffs.squeeze().values, -np.ones(20)) + def test_subset_sub_var_raises(self, v: Variable, subset: xr.DataArray) -> None: + with pytest.raises(ValueError, match="exact"): + subset - v class TestSuperset: - def test_add_superset_pins_to_lhs_coords( + """Under v1, superset operations raise ValueError (exact join).""" + + def test_add_superset_raises(self, v: Variable, superset: xr.DataArray) -> None: + with pytest.raises(ValueError, match="exact"): + v + superset + + def test_add_superset_join_left( self, v: Variable, superset: xr.DataArray ) -> None: - result = v + superset + result = v.add(superset, join="left") assert result.sizes["dim_2"] == v.sizes["dim_2"] assert not np.isnan(result.const.values).any() - def test_add_var_commutative(self, v: Variable, superset: xr.DataArray) -> None: - assert_linequal(superset + v, v + superset) + def test_mul_superset_raises(self, v: Variable, superset: xr.DataArray) -> None: + with pytest.raises(ValueError, match="exact"): + v * superset - def test_sub_var_commutative(self, v: Variable, superset: xr.DataArray) -> None: - assert_linequal(superset - v, -v + superset) - - def test_mul_var_commutative(self, v: Variable, superset: xr.DataArray) -> None: - assert_linequal(superset * v, v * superset) - - def test_mul_superset_pins_to_lhs_coords( + def test_mul_superset_join_inner( self, v: Variable, superset: xr.DataArray ) -> None: - result = v * superset + result = v.mul(superset, join="inner") assert result.sizes["dim_2"] == v.sizes["dim_2"] assert not np.isnan(result.coeffs.values).any() - def test_div_superset_pins_to_lhs_coords(self, v: Variable) -> None: + def test_div_superset_raises(self, v: Variable) -> None: + superset_nonzero = xr.DataArray( + np.arange(1, 26, dtype=float), + dims=["dim_2"], + coords={"dim_2": range(25)}, + ) + with pytest.raises(ValueError, match="exact"): + v / superset_nonzero + + def test_div_superset_join_inner(self, v: Variable) -> None: superset_nonzero = xr.DataArray( np.arange(1, 26, dtype=float), dims=["dim_2"], coords={"dim_2": range(25)}, ) - result = v / superset_nonzero + result = v.div(superset_nonzero, join="inner") assert result.sizes["dim_2"] == v.sizes["dim_2"] assert not np.isnan(result.coeffs.values).any() class TestDisjoint: - def test_add_disjoint_fills_zeros(self, v: Variable) -> None: + """Under v1, disjoint operations raise ValueError (exact join).""" + + def test_add_disjoint_raises(self, v: Variable) -> None: disjoint = xr.DataArray( [100.0, 200.0], dims=["dim_2"], coords={"dim_2": [50, 60]} ) - result = v + disjoint - assert result.sizes["dim_2"] == v.sizes["dim_2"] - assert not np.isnan(result.const.values).any() - np.testing.assert_array_equal(result.const.values, np.zeros(20)) + with pytest.raises(ValueError, match="exact"): + v + disjoint + + def test_add_disjoint_join_outer(self, v: Variable) -> None: + disjoint = xr.DataArray( + [100.0, 200.0], dims=["dim_2"], coords={"dim_2": [50, 60]} + ) + result = v.add(disjoint, join="outer") + assert result.sizes["dim_2"] == 22 # union of [0..19] and [50, 60] + + def test_mul_disjoint_raises(self, v: Variable) -> None: + disjoint = xr.DataArray( + [10.0, 20.0], dims=["dim_2"], coords={"dim_2": [50, 60]} + ) + with pytest.raises(ValueError, match="exact"): + v * disjoint - def test_mul_disjoint_fills_zeros(self, v: Variable) -> None: + def test_mul_disjoint_join_left(self, v: Variable) -> None: disjoint = xr.DataArray( [10.0, 20.0], dims=["dim_2"], coords={"dim_2": [50, 60]} ) - result = v * disjoint + result = v.mul(disjoint, join="left") assert result.sizes["dim_2"] == v.sizes["dim_2"] assert not np.isnan(result.coeffs.values).any() np.testing.assert_array_equal(result.coeffs.squeeze().values, np.zeros(20)) - def test_div_disjoint_preserves_coeffs(self, v: Variable) -> None: + def test_div_disjoint_raises(self, v: Variable) -> None: disjoint = xr.DataArray( [10.0, 20.0], dims=["dim_2"], coords={"dim_2": [50, 60]} ) - result = v / disjoint - assert result.sizes["dim_2"] == v.sizes["dim_2"] - assert not np.isnan(result.coeffs.values).any() - np.testing.assert_array_equal(result.coeffs.squeeze().values, np.ones(20)) + with pytest.raises(ValueError, match="exact"): + v / disjoint class TestCommutativity: - @pytest.mark.parametrize( - "make_lhs,make_rhs", - [ - (lambda v, s: s * v, lambda v, s: v * s), - (lambda v, s: s * (1 * v), lambda v, s: (1 * v) * s), - (lambda v, s: s + v, lambda v, s: v + s), - (lambda v, s: s + (v + 5), lambda v, s: (v + 5) + s), - ], - ids=["subset*var", "subset*expr", "subset+var", "subset+expr"], - ) - def test_commutativity( - self, - v: Variable, - subset: xr.DataArray, - make_lhs: Any, - make_rhs: Any, + """Commutativity tests with matching coordinates under v1.""" + + def test_add_commutativity_matching_coords( + self, v: Variable, matching: xr.DataArray ) -> None: - assert_linequal(make_lhs(v, subset), make_rhs(v, subset)) + assert_linequal(v + matching, matching + v) - def test_sub_var_anticommutative( - self, v: Variable, subset: xr.DataArray + def test_mul_commutativity_matching_coords( + self, v: Variable, matching: xr.DataArray ) -> None: - assert_linequal(subset - v, -v + subset) + assert_linequal(v * matching, matching * v) - def test_sub_expr_anticommutative( + def test_subset_raises_both_sides( self, v: Variable, subset: xr.DataArray ) -> None: - expr = v + 5 - assert_linequal(subset - expr, -(expr - subset)) + """Subset operations raise regardless of operand order.""" + with pytest.raises(ValueError, match="exact"): + v * subset + with pytest.raises(ValueError, match="exact"): + subset * v - def test_add_commutativity_full_coords(self, v: Variable) -> None: - full = xr.DataArray( - np.arange(20, dtype=float), - dims=["dim_2"], - coords={"dim_2": range(20)}, + def test_commutativity_with_join( + self, v: Variable, subset: xr.DataArray + ) -> None: + """Commutativity holds with explicit join.""" + assert_linequal( + v.add(subset, join="inner"), + subset + v.reindex({"dim_2": [1, 3]}), ) - assert_linequal(v + full, full + v) class TestQuadratic: - def test_quadexpr_add_subset( + """Under v1, subset operations on quadratic expressions raise.""" + + def test_quadexpr_add_subset_raises( + self, v: Variable, subset: xr.DataArray + ) -> None: + qexpr = v * v + with pytest.raises(ValueError, match="exact"): + qexpr + subset + + def test_quadexpr_add_subset_join_left( self, v: Variable, subset: xr.DataArray, expected_fill: np.ndarray, ) -> None: qexpr = v * v - result = qexpr + subset + result = qexpr.add(subset, join="left") assert isinstance(result, QuadraticExpression) assert result.sizes["dim_2"] == v.sizes["dim_2"] assert not np.isnan(result.const.values).any() np.testing.assert_array_equal(result.const.values, expected_fill) - def test_quadexpr_sub_subset( + def test_quadexpr_sub_subset_raises( + self, v: Variable, subset: xr.DataArray + ) -> None: + qexpr = v * v + with pytest.raises(ValueError, match="exact"): + qexpr - subset + + def test_quadexpr_sub_subset_join_left( self, v: Variable, subset: xr.DataArray, expected_fill: np.ndarray, ) -> None: qexpr = v * v - result = qexpr - subset + result = qexpr.sub(subset, join="left") assert isinstance(result, QuadraticExpression) assert result.sizes["dim_2"] == v.sizes["dim_2"] assert not np.isnan(result.const.values).any() np.testing.assert_array_equal(result.const.values, -expected_fill) - def test_quadexpr_mul_subset( - self, - v: Variable, - subset: xr.DataArray, - expected_fill: np.ndarray, + def test_quadexpr_mul_subset_raises( + self, v: Variable, subset: xr.DataArray ) -> None: qexpr = v * v - result = qexpr * subset - assert isinstance(result, QuadraticExpression) - assert result.sizes["dim_2"] == v.sizes["dim_2"] - assert not np.isnan(result.coeffs.values).any() - np.testing.assert_array_equal(result.coeffs.squeeze().values, expected_fill) + with pytest.raises(ValueError, match="exact"): + qexpr * subset - def test_subset_mul_quadexpr( + def test_quadexpr_mul_subset_join_left( self, v: Variable, subset: xr.DataArray, expected_fill: np.ndarray, ) -> None: qexpr = v * v - result = subset * qexpr + result = qexpr.mul(subset, join="left") assert isinstance(result, QuadraticExpression) assert result.sizes["dim_2"] == v.sizes["dim_2"] assert not np.isnan(result.coeffs.values).any() np.testing.assert_array_equal(result.coeffs.squeeze().values, expected_fill) - def test_subset_add_quadexpr(self, v: Variable, subset: xr.DataArray) -> None: + def test_quadexpr_add_matching( + self, v: Variable, matching: xr.DataArray + ) -> None: qexpr = v * v - assert_quadequal(subset + qexpr, qexpr + subset) + assert_quadequal(matching + qexpr, qexpr + matching) class TestMissingValues: """Same shape as variable but with NaN entries in the constant.""" @@ -942,17 +1010,30 @@ def test_quadexpr_add_nan( ) class TestMultiDim: - def test_multidim_subset_mul(self, m: Model) -> None: + """Under v1, multi-dim subset operations raise.""" + + def test_multidim_subset_mul_raises(self, m: Model) -> None: coords_a = pd.RangeIndex(4, name="a") coords_b = pd.RangeIndex(5, name="b") w = m.add_variables(coords=[coords_a, coords_b], name="w") + subset_2d = xr.DataArray( + [[2.0, 3.0], [4.0, 5.0]], + dims=["a", "b"], + coords={"a": [1, 3], "b": [0, 4]}, + ) + with pytest.raises(ValueError, match="exact"): + w * subset_2d + def test_multidim_subset_mul_join_left(self, m: Model) -> None: + coords_a = pd.RangeIndex(4, name="a") + coords_b = pd.RangeIndex(5, name="b") + w = m.add_variables(coords=[coords_a, coords_b], name="w") subset_2d = xr.DataArray( [[2.0, 3.0], [4.0, 5.0]], dims=["a", "b"], coords={"a": [1, 3], "b": [0, 4]}, ) - result = w * subset_2d + result = w.mul(subset_2d, join="left") assert result.sizes["a"] == 4 assert result.sizes["b"] == 5 assert not np.isnan(result.coeffs.values).any() @@ -961,23 +1042,17 @@ def test_multidim_subset_mul(self, m: Model) -> None: assert result.coeffs.squeeze().sel(a=0, b=0).item() == pytest.approx(0.0) assert result.coeffs.squeeze().sel(a=1, b=2).item() == pytest.approx(0.0) - def test_multidim_subset_add(self, m: Model) -> None: + def test_multidim_subset_add_raises(self, m: Model) -> None: coords_a = pd.RangeIndex(4, name="a") coords_b = pd.RangeIndex(5, name="b") w = m.add_variables(coords=[coords_a, coords_b], name="w") - subset_2d = xr.DataArray( [[2.0, 3.0], [4.0, 5.0]], dims=["a", "b"], coords={"a": [1, 3], "b": [0, 4]}, ) - result = w + subset_2d - assert result.sizes["a"] == 4 - assert result.sizes["b"] == 5 - assert not np.isnan(result.const.values).any() - assert result.const.sel(a=1, b=0).item() == pytest.approx(2.0) - assert result.const.sel(a=3, b=4).item() == pytest.approx(5.0) - assert result.const.sel(a=0, b=0).item() == pytest.approx(0.0) + with pytest.raises(ValueError, match="exact"): + w + subset_2d class TestXarrayCompat: def test_da_eq_da_still_works(self) -> None: @@ -1847,12 +1922,14 @@ def c(self, m2: Model) -> Variable: return m2.variables["c"] class TestAddition: - def test_add_join_none_preserves_default( + def test_add_join_none_raises_on_mismatch( self, a: Variable, b: Variable ) -> None: - result_default = a.to_linexpr() + b.to_linexpr() - result_none = a.to_linexpr().add(b.to_linexpr(), join=None) - assert_linequal(result_default, result_none) + # a has i=[0,1,2], b has i=[1,2,3] — exact default raises + with pytest.raises(ValueError, match="exact"): + a.to_linexpr() + b.to_linexpr() + with pytest.raises(ValueError, match="exact"): + a.to_linexpr().add(b.to_linexpr(), join=None) def test_add_expr_join_inner(self, a: Variable, b: Variable) -> None: result = a.to_linexpr().add(b.to_linexpr(), join="inner") @@ -2108,12 +2185,12 @@ def test_div_constant_outer_fill_values(self, a: Variable) -> None: class TestQuadratic: def test_quadratic_add_constant_join_inner( - self, a: Variable, b: Variable + self, a: Variable, c: Variable ) -> None: - quad = a.to_linexpr() * b.to_linexpr() + quad = a.to_linexpr() * c.to_linexpr() const = xr.DataArray([10, 20, 30], dims=["i"], coords={"i": [1, 2, 3]}) result = quad.add(const, join="inner") - assert list(result.data.indexes["i"]) == [1, 2, 3] + assert list(result.data.indexes["i"]) == [1, 2] def test_quadratic_add_expr_join_inner(self, a: Variable) -> None: quad = a.to_linexpr() * a.to_linexpr() @@ -2122,9 +2199,9 @@ def test_quadratic_add_expr_join_inner(self, a: Variable) -> None: assert list(result.data.indexes["i"]) == [0, 1] def test_quadratic_mul_constant_join_inner( - self, a: Variable, b: Variable + self, a: Variable, c: Variable ) -> None: - quad = a.to_linexpr() * b.to_linexpr() + quad = a.to_linexpr() * c.to_linexpr() const = xr.DataArray([2, 3, 4], dims=["i"], coords={"i": [1, 2, 3]}) result = quad.mul(const, join="inner") - assert list(result.data.indexes["i"]) == [1, 2, 3] + assert list(result.data.indexes["i"]) == [1, 2] From 2814a3fd9fd2d3f9d99e8212cf809ba59e8933d1 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Tue, 10 Mar 2026 16:29:02 +0100 Subject: [PATCH 8/9] Fix mypy errors and pytest importmode for CI - Add return type annotations (Generator) to all v1_convention fixtures - Add importmode = "importlib" to pytest config to fix import mismatch when linopy is installed from wheel and source dir is also present - Use tuple literal in loop to fix arg-type error Co-Authored-By: Claude Opus 4.6 --- pyproject.toml | 1 + test/conftest.py | 3 ++- test/test_algebraic_properties.py | 4 +++- test/test_common.py | 4 +++- test/test_constraints.py | 3 ++- test/test_linear_expression.py | 3 ++- test/test_linear_expression_legacy.py | 2 +- test/test_typing.py | 4 +++- 8 files changed, 17 insertions(+), 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 14a53a22..8f352393 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -101,6 +101,7 @@ write_to = "linopy/version.py" version_scheme = "no-guess-dev" [tool.pytest.ini_options] +importmode = "importlib" testpaths = ["test"] norecursedirs = ["dev-scripts", "doc", "examples", "benchmark"] markers = [ diff --git a/test/conftest.py b/test/conftest.py index d239531e..7b2669f1 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -3,6 +3,7 @@ from __future__ import annotations import os +from collections.abc import Generator from typing import TYPE_CHECKING import pandas as pd @@ -60,7 +61,7 @@ def pytest_collection_modifyitems( @pytest.fixture -def v1_convention(): +def v1_convention() -> Generator[None, None, None]: """Set arithmetic_convention to 'v1' for the duration of a test.""" linopy.options["arithmetic_convention"] = "v1" yield diff --git a/test/test_algebraic_properties.py b/test/test_algebraic_properties.py index 39d11c3f..04103b61 100644 --- a/test/test_algebraic_properties.py +++ b/test/test_algebraic_properties.py @@ -39,6 +39,8 @@ from __future__ import annotations +from collections.abc import Generator + import numpy as np import pandas as pd import pytest @@ -51,7 +53,7 @@ @pytest.fixture(autouse=True) -def _use_v1_convention(): +def _use_v1_convention() -> Generator[None, None, None]: """Use v1 arithmetic convention for all tests in this module.""" linopy.options["arithmetic_convention"] = "v1" yield diff --git a/test/test_common.py b/test/test_common.py index 6d7192b7..719ab093 100644 --- a/test/test_common.py +++ b/test/test_common.py @@ -5,6 +5,8 @@ @author: fabian """ +from collections.abc import Generator + import numpy as np import pandas as pd import polars as pl @@ -29,7 +31,7 @@ @pytest.fixture(autouse=True) -def _use_v1_convention(): +def _use_v1_convention() -> Generator[None, None, None]: """Use v1 arithmetic convention for all tests in this module.""" linopy.options["arithmetic_convention"] = "v1" yield diff --git a/test/test_constraints.py b/test/test_constraints.py index 4c111e14..e94f0152 100644 --- a/test/test_constraints.py +++ b/test/test_constraints.py @@ -5,6 +5,7 @@ @author: fabulous """ +from collections.abc import Generator from typing import Any import dask @@ -20,7 +21,7 @@ @pytest.fixture(autouse=True) -def _use_v1_convention(): +def _use_v1_convention() -> Generator[None, None, None]: """Use v1 arithmetic convention for all tests in this module.""" linopy.options["arithmetic_convention"] = "v1" yield diff --git a/test/test_linear_expression.py b/test/test_linear_expression.py index 9bf2ee80..a4e4abfa 100644 --- a/test/test_linear_expression.py +++ b/test/test_linear_expression.py @@ -7,6 +7,7 @@ from __future__ import annotations +from collections.abc import Generator from typing import Any import numpy as np @@ -25,7 +26,7 @@ @pytest.fixture(autouse=True) -def _use_v1_convention(): +def _use_v1_convention() -> Generator[None, None, None]: """Use v1 arithmetic convention for all tests in this module.""" linopy.options["arithmetic_convention"] = "v1" yield diff --git a/test/test_linear_expression_legacy.py b/test/test_linear_expression_legacy.py index d3b8d426..1378f48d 100644 --- a/test/test_linear_expression_legacy.py +++ b/test/test_linear_expression_legacy.py @@ -1920,7 +1920,7 @@ def test_add_constant_join_override(self, a: Variable, c: Variable) -> None: def test_add_same_coords_all_joins(self, a: Variable, c: Variable) -> None: expr_a = 1 * a + 5 const = xr.DataArray([1, 2, 3], dims=["i"], coords={"i": [0, 1, 2]}) - for join in ["override", "outer", "inner"]: + for join in ("override", "outer", "inner"): result = expr_a.add(const, join=join) assert list(result.coords["i"].values) == [0, 1, 2] np.testing.assert_array_equal(result.const.values, [6, 7, 8]) diff --git a/test/test_typing.py b/test/test_typing.py index 6e7d75f8..566583c2 100644 --- a/test/test_typing.py +++ b/test/test_typing.py @@ -1,3 +1,5 @@ +from collections.abc import Generator + import pytest import xarray as xr @@ -5,7 +7,7 @@ @pytest.fixture(autouse=True) -def _use_v1_convention(): +def _use_v1_convention() -> Generator[None, None, None]: """Use v1 arithmetic convention for all tests in this module.""" linopy.options["arithmetic_convention"] = "v1" yield From 1a5f0bbc127f534b8fe58c1d136c9ff9fbf1f493 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Tue, 10 Mar 2026 16:48:21 +0100 Subject: [PATCH 9/9] Fix CI: move import linopy to lazy in conftest.py Top-level `import linopy` in conftest.py caused pytest to import the package from site-packages before collecting doctests from the source directory, triggering import file mismatch errors on all platforms. Move the import inside fixture functions where it's actually needed. Also revert the unnecessary test.yml and importmode changes. Co-Authored-By: Claude Opus 4.6 --- pyproject.toml | 1 - test/conftest.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8f352393..14a53a22 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -101,7 +101,6 @@ write_to = "linopy/version.py" version_scheme = "no-guess-dev" [tool.pytest.ini_options] -importmode = "importlib" testpaths = ["test"] norecursedirs = ["dev-scripts", "doc", "examples", "benchmark"] markers = [ diff --git a/test/conftest.py b/test/conftest.py index 7b2669f1..5e2170a3 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -9,8 +9,6 @@ import pandas as pd import pytest -import linopy - if TYPE_CHECKING: from linopy import Model, Variable @@ -63,6 +61,8 @@ def pytest_collection_modifyitems( @pytest.fixture def v1_convention() -> Generator[None, None, None]: """Set arithmetic_convention to 'v1' for the duration of a test.""" + import linopy + linopy.options["arithmetic_convention"] = "v1" yield linopy.options["arithmetic_convention"] = "legacy"