Skip to content

Commit 59f92ae

Browse files
Fix multiplication of constant-only LinearExpression (#568)
* Fix multiplication of constant-only LinearExpression When multiplying a constant-only LinearExpression with another expression, the code would fail with IndexError when trying to access _term=0 on an empty term dimension. The fix correctly returns a LinearExpression (not QuadraticExpression) since multiplying by a constant preserves linearity. * fix: add type casts for mypy * fix: use cast instead of isinstance for runtime type check * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent e05518d commit 59f92ae

4 files changed

Lines changed: 103 additions & 3 deletions

File tree

doc/release_notes.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@ Upcoming Version
77
* Fix docs (pick highs solver)
88
* Add the `sphinx-copybutton` to the documentation
99

10+
Upcoming Version
11+
----------------
12+
13+
* Fix multiplication of constant-only ``LinearExpression`` with other expressions
14+
1015
Version 0.6.1
1116
--------------
1217

linopy/expressions.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from collections.abc import Callable, Hashable, Iterator, Mapping, Sequence
1414
from dataclasses import dataclass, field
1515
from itertools import product, zip_longest
16-
from typing import TYPE_CHECKING, Any, TypeVar, overload
16+
from typing import TYPE_CHECKING, Any, TypeVar, cast, overload
1717
from warnings import warn
1818

1919
import numpy as np
@@ -507,12 +507,18 @@ def __neg__(self: GenericExpression) -> GenericExpression:
507507

508508
def _multiply_by_linear_expression(
509509
self, other: LinearExpression | ScalarLinearExpression
510-
) -> QuadraticExpression:
510+
) -> LinearExpression | QuadraticExpression:
511511
if isinstance(other, ScalarLinearExpression):
512512
other = other.to_linexpr()
513513

514514
if other.nterm > 1:
515515
raise TypeError("Multiplication of multiple terms is not supported.")
516+
517+
if other.is_constant:
518+
return cast(LinearExpression, self._multiply_by_constant(other.const))
519+
if self.is_constant:
520+
return cast(LinearExpression, other._multiply_by_constant(self.const))
521+
516522
# multiplication: (v1 + c1) * (v2 + c2) = v1 * v2 + c1 * v2 + c2 * v1 + c1 * c2
517523
# with v being the variables and c the constants
518524
# merge on factor dimension only returns v1 * v2 + c1 * c2

linopy/variables.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from typing import (
1515
TYPE_CHECKING,
1616
Any,
17+
cast,
1718
overload,
1819
)
1920
from warnings import warn
@@ -420,7 +421,9 @@ def __pow__(self, other: int) -> QuadraticExpression:
420421
return NotImplemented
421422
if other == 2:
422423
expr = self.to_linexpr()
423-
return expr._multiply_by_linear_expression(expr)
424+
return cast(
425+
"QuadraticExpression", expr._multiply_by_linear_expression(expr)
426+
)
424427
raise ValueError("Can only raise to the power of 2")
425428

426429
@overload

test/test_linear_expression.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1313,3 +1313,89 @@ def test_simplify_partial_cancellation(x: Variable, y: Variable) -> None:
13131313
assert all(simplified.coeffs.values == 3.0), (
13141314
f"Expected coefficient 3.0, got {simplified.coeffs.values}"
13151315
)
1316+
1317+
1318+
def test_constant_only_expression_mul_dataarray(m: Model) -> None:
1319+
const_arr = xr.DataArray([2, 3], dims=["dim_0"])
1320+
const_expr = LinearExpression(const_arr, m)
1321+
assert const_expr.is_constant
1322+
assert const_expr.nterm == 0
1323+
1324+
data_arr = xr.DataArray([10, 20], dims=["dim_0"])
1325+
expected_const = const_arr * data_arr
1326+
1327+
result = const_expr * data_arr
1328+
assert isinstance(result, LinearExpression)
1329+
assert result.is_constant
1330+
assert (result.const == expected_const).all()
1331+
1332+
result_rev = data_arr * const_expr
1333+
assert isinstance(result_rev, LinearExpression)
1334+
assert result_rev.is_constant
1335+
assert (result_rev.const == expected_const).all()
1336+
1337+
1338+
def test_constant_only_expression_mul_linexpr_with_vars(m: Model, x: Variable) -> None:
1339+
const_arr = xr.DataArray([2, 3], dims=["dim_0"])
1340+
const_expr = LinearExpression(const_arr, m)
1341+
assert const_expr.is_constant
1342+
assert const_expr.nterm == 0
1343+
1344+
expr_with_vars = 1 * x + 5
1345+
expected_coeffs = const_arr
1346+
expected_const = const_arr * 5
1347+
1348+
result = const_expr * expr_with_vars
1349+
assert isinstance(result, LinearExpression)
1350+
assert (result.coeffs == expected_coeffs).all()
1351+
assert (result.const == expected_const).all()
1352+
1353+
result_rev = expr_with_vars * const_expr
1354+
assert isinstance(result_rev, LinearExpression)
1355+
assert (result_rev.coeffs == expected_coeffs).all()
1356+
assert (result_rev.const == expected_const).all()
1357+
1358+
1359+
def test_constant_only_expression_mul_constant_only(m: Model) -> None:
1360+
const_arr = xr.DataArray([2, 3], dims=["dim_0"])
1361+
const_arr2 = xr.DataArray([4, 5], dims=["dim_0"])
1362+
const_expr = LinearExpression(const_arr, m)
1363+
const_expr2 = LinearExpression(const_arr2, m)
1364+
assert const_expr.is_constant
1365+
assert const_expr2.is_constant
1366+
1367+
expected_const = const_arr * const_arr2
1368+
1369+
result = const_expr * const_expr2
1370+
assert isinstance(result, LinearExpression)
1371+
assert result.is_constant
1372+
assert (result.const == expected_const).all()
1373+
1374+
result_rev = const_expr2 * const_expr
1375+
assert isinstance(result_rev, LinearExpression)
1376+
assert result_rev.is_constant
1377+
assert (result_rev.const == expected_const).all()
1378+
1379+
1380+
def test_constant_only_expression_mul_linexpr_with_vars_and_const(
1381+
m: Model, x: Variable
1382+
) -> None:
1383+
const_arr = xr.DataArray([2, 3], dims=["dim_0"])
1384+
const_expr = LinearExpression(const_arr, m)
1385+
assert const_expr.is_constant
1386+
1387+
expr_with_vars_and_const = 4 * x + 10
1388+
expected_coeffs = const_arr * 4
1389+
expected_const = const_arr * 10
1390+
1391+
result = const_expr * expr_with_vars_and_const
1392+
assert isinstance(result, LinearExpression)
1393+
assert not result.is_constant
1394+
assert (result.coeffs == expected_coeffs).all()
1395+
assert (result.const == expected_const).all()
1396+
1397+
result_rev = expr_with_vars_and_const * const_expr
1398+
assert isinstance(result_rev, LinearExpression)
1399+
assert not result_rev.is_constant
1400+
assert (result_rev.coeffs == expected_coeffs).all()
1401+
assert (result_rev.const == expected_const).all()

0 commit comments

Comments
 (0)