Skip to content

Commit e16ff3e

Browse files
author
Robbie Muir
committed
fix types
1 parent 9ce2005 commit e16ff3e

4 files changed

Lines changed: 55 additions & 4 deletions

File tree

doc/release_notes.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Release Notes
44
Upcoming Version
55
----------------
66

7+
* Fix the handling of multiplication between ``LinearExpression`` and constants with a subset of dimensions. Align with ``Variable`` behaviour
78
* Fix docs (pick highs solver)
89
* Add the `sphinx-copybutton` to the documentation
910
* Add ``auto_mask`` parameter to ``Model`` class that automatically masks variables and constraints where bounds, coefficients, or RHS values contain NaN. This eliminates the need to manually create mask arrays when working with sparse or incomplete data.

linopy/expressions.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -538,7 +538,6 @@ def _multiply_by_constant(
538538
) -> GenericExpression:
539539
multiplier = as_dataarray(other, coords=self.coords, dims=self.coord_dims)
540540
coeffs = self.coeffs * multiplier
541-
assert all(coeffs.sizes[d] == s for d, s in self.coeffs.sizes.items())
542541
const = self.const * multiplier
543542
return self.assign(coeffs=coeffs, const=const)
544543

linopy/testing.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
from __future__ import annotations
22

3+
from collections.abc import Iterable
4+
5+
import numpy as np
36
from xarray.testing import assert_equal
47

58
from linopy.constraints import Constraint, _con_unwrap
@@ -72,3 +75,13 @@ def assert_model_equal(a: Model, b: Model) -> None:
7275
assert a.termination_condition == b.termination_condition
7376

7477
assert a.type == b.type
78+
79+
80+
def assert_lists_equal(x: Iterable[float], b: Iterable[float]) -> None:
81+
x = list(x)
82+
b = list(b)
83+
assert len(x) == len(b)
84+
for xi, bi in zip(x, b):
85+
if np.isnan(xi) and np.isnan(bi):
86+
continue
87+
assert xi == bi

test/test_linear_expression.py

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@
1414
import xarray as xr
1515
from xarray.testing import assert_equal
1616

17-
from linopy import LinearExpression, Model, QuadraticExpression, Variable, merge
17+
from linopy import LinearExpression, Model, QuadraticExpression, Variable
1818
from linopy.constants import HELPER_DIMS, TERM_DIM
19-
from linopy.expressions import ScalarLinearExpression
20-
from linopy.testing import assert_linequal, assert_quadequal
19+
from linopy.expressions import ScalarLinearExpression, merge
20+
from linopy.testing import assert_linequal, assert_lists_equal, assert_quadequal
2121
from linopy.variables import ScalarVariable
2222

2323

@@ -238,6 +238,44 @@ def test_linear_expression_with_multiplication(x: Variable) -> None:
238238
assert expr.__rmul__(object()) is NotImplemented
239239

240240

241+
def test_linear_expression_multiplication_with_missing_coords() -> None:
242+
m = Model()
243+
full_index = pd.Index(range(5), name="i")
244+
x = m.add_variables(coords=[full_index])
245+
nan = float("nan")
246+
scale = xr.DataArray([10.0, 30.0], dims=["i"], coords={"i": [1, 3]})
247+
248+
# These two expressions should produce the same result
249+
r1 = x * scale
250+
r2 = (1 * x) * scale
251+
252+
for result in [r1, r2]:
253+
assert result.coords.equals(x.coords)
254+
assert result.vars.equals(r1.vars)
255+
256+
# Use pandas to make sure nans are considered equal
257+
expected_coeffs = [nan, 10.0, nan, 30.0, nan]
258+
assert_lists_equal(result.coeffs.values.squeeze(), expected_coeffs)
259+
260+
261+
def test_linear_expression_with_missing_coords_in_coeff_and_const() -> None:
262+
m = Model()
263+
full_index = pd.Index(range(5), name="i")
264+
x = m.add_variables(coords=[full_index])
265+
nan = float("nan")
266+
scale = xr.DataArray([10.0, 30.0], dims=["i"], coords={"i": [1, 3]})
267+
const = xr.DataArray([1.0, 2.0], dims=["i"], coords={"i": [0, 1]})
268+
269+
# These two expressions should produce the same result
270+
result = (x + const) * scale
271+
assert result.coords.equals(x.coords)
272+
273+
expected_coeffs = [nan, 10.0, nan, 30.0, nan]
274+
expected_const = [nan, 20.0, nan, 0.0, nan] # Constants are filled with zeros
275+
assert_lists_equal(result.coeffs.values.squeeze(), expected_coeffs)
276+
assert_lists_equal(result.const.values.squeeze(), expected_const)
277+
278+
241279
def test_linear_expression_with_addition(m: Model, x: Variable, y: Variable) -> None:
242280
expr = 10 * x + y
243281
assert isinstance(expr, LinearExpression)

0 commit comments

Comments
 (0)