Skip to content

Commit b3c8431

Browse files
Merge branch 'master' into feature/auto-masking1
2 parents 20833e9 + 59f92ae commit b3c8431

9 files changed

Lines changed: 116 additions & 15 deletions

doc/release_notes.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,15 @@ Release Notes
44
Upcoming Version
55
----------------
66

7+
* Fix docs (pick highs solver)
78
* Add the `sphinx-copybutton` to the documentation
89
* Add ``auto_mask`` parameter to ``Model`` class that automatically masks variables and constraints where bounds, coefficients, or RHS values contain NaN. This eliminates the need to manually create mask arrays when working with sparse or incomplete data.
910

11+
Upcoming Version
12+
----------------
13+
14+
* Fix multiplication of constant-only ``LinearExpression`` with other expressions
15+
1016
Version 0.6.1
1117
--------------
1218

examples/create-a-model-with-coordinates.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@
150150
"metadata": {},
151151
"outputs": [],
152152
"source": [
153-
"m.solve()"
153+
"m.solve(solver_name=\"highs\")"
154154
]
155155
},
156156
{

examples/create-a-model.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@
215215
"metadata": {},
216216
"outputs": [],
217217
"source": [
218-
"m.solve()"
218+
"m.solve(solver_name=\"highs\")"
219219
]
220220
},
221221
{

examples/manipulating-models.ipynb

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,9 @@
4848
"con2 = m.add_constraints(5 * x + 2 * y >= 3 * factor, name=\"con2\")\n",
4949
"\n",
5050
"m.add_objective(x + 2 * y)\n",
51-
"m.solve()\n",
51+
"m.solve(solver_name=\"highs\")\n",
5252
"\n",
53-
"m.solve()\n",
53+
"m.solve(solver_name=\"highs\")\n",
5454
"sol = m.solution.to_dataframe()\n",
5555
"sol.plot(grid=True, ylabel=\"Optimal Value\")"
5656
]
@@ -95,7 +95,7 @@
9595
"metadata": {},
9696
"outputs": [],
9797
"source": [
98-
"m.solve()\n",
98+
"m.solve(solver_name=\"highs\")\n",
9999
"sol = m.solution.to_dataframe()\n",
100100
"sol.plot(grid=True, ylabel=\"Optimal Value\")"
101101
]
@@ -137,7 +137,7 @@
137137
"metadata": {},
138138
"outputs": [],
139139
"source": [
140-
"m.solve()\n",
140+
"m.solve(solver_name=\"highs\")\n",
141141
"sol = m.solution.to_dataframe()\n",
142142
"sol.plot(grid=True, ylabel=\"Optimal Value\")"
143143
]
@@ -190,7 +190,7 @@
190190
"metadata": {},
191191
"outputs": [],
192192
"source": [
193-
"m.solve()\n",
193+
"m.solve(solver_name=\"highs\")\n",
194194
"sol = m.solution.to_dataframe()\n",
195195
"sol.plot(grid=True, ylabel=\"Optimal Value\")"
196196
]
@@ -242,7 +242,7 @@
242242
"metadata": {},
243243
"outputs": [],
244244
"source": [
245-
"m.solve()\n",
245+
"m.solve(solver_name=\"highs\")\n",
246246
"sol = m.solution.to_dataframe()\n",
247247
"sol.plot(grid=True, ylabel=\"Optimal Value\")"
248248
]
@@ -276,7 +276,7 @@
276276
"metadata": {},
277277
"outputs": [],
278278
"source": [
279-
"m.solve()\n",
279+
"m.solve(solver_name=\"highs\")\n",
280280
"sol = m.solution.to_dataframe()\n",
281281
"sol.plot(grid=True, ylabel=\"Optimal Value\")"
282282
]

examples/solve-on-oetc.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@
169169
"start_time = time.time()\n",
170170
"\n",
171171
"try:\n",
172-
" status, termination_condition = m.solve(remote=oetc_handler)\n",
172+
" status, termination_condition = m.solve(remote=oetc_handler, solver_name=\"highs\")\n",
173173
"\n",
174174
" end_time = time.time()\n",
175175
" total_time = end_time - start_time\n",

examples/transport-tutorial.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,7 @@
407407
"cell_type": "markdown",
408408
"metadata": {},
409409
"source": [
410-
"In the `solve()` function, you can specify a `solver_name`. The default solver, however, will be the first from the list we printed above."
410+
"In the `solve()` function, you can specify a `solver_name`. The default solver, however, will be the first from the list we printed above. In this example, we will specify a solver explicitly to avoid licensing issues."
411411
]
412412
},
413413
{
@@ -417,7 +417,7 @@
417417
"outputs": [],
418418
"source": [
419419
"# Solve the model\n",
420-
"m.solve()"
420+
"m.solve(solver_name=\"highs\")"
421421
]
422422
},
423423
{

linopy/expressions.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from collections.abc import Callable, Hashable, Iterator, Mapping, Sequence
1414
from dataclasses import dataclass, field
1515
from itertools import product, zip_longest
16-
from typing import TYPE_CHECKING, Any, TypeVar, overload
16+
from typing import TYPE_CHECKING, Any, TypeVar, cast, overload
1717
from warnings import warn
1818

1919
import numpy as np
@@ -507,12 +507,18 @@ def __neg__(self: GenericExpression) -> GenericExpression:
507507

508508
def _multiply_by_linear_expression(
509509
self, other: LinearExpression | ScalarLinearExpression
510-
) -> QuadraticExpression:
510+
) -> LinearExpression | QuadraticExpression:
511511
if isinstance(other, ScalarLinearExpression):
512512
other = other.to_linexpr()
513513

514514
if other.nterm > 1:
515515
raise TypeError("Multiplication of multiple terms is not supported.")
516+
517+
if other.is_constant:
518+
return cast(LinearExpression, self._multiply_by_constant(other.const))
519+
if self.is_constant:
520+
return cast(LinearExpression, other._multiply_by_constant(self.const))
521+
516522
# multiplication: (v1 + c1) * (v2 + c2) = v1 * v2 + c1 * v2 + c2 * v1 + c1 * c2
517523
# with v being the variables and c the constants
518524
# merge on factor dimension only returns v1 * v2 + c1 * c2

linopy/variables.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from typing import (
1515
TYPE_CHECKING,
1616
Any,
17+
cast,
1718
overload,
1819
)
1920
from warnings import warn
@@ -420,7 +421,9 @@ def __pow__(self, other: int) -> QuadraticExpression:
420421
return NotImplemented
421422
if other == 2:
422423
expr = self.to_linexpr()
423-
return expr._multiply_by_linear_expression(expr)
424+
return cast(
425+
"QuadraticExpression", expr._multiply_by_linear_expression(expr)
426+
)
424427
raise ValueError("Can only raise to the power of 2")
425428

426429
@overload

test/test_linear_expression.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1313,3 +1313,89 @@ def test_simplify_partial_cancellation(x: Variable, y: Variable) -> None:
13131313
assert all(simplified.coeffs.values == 3.0), (
13141314
f"Expected coefficient 3.0, got {simplified.coeffs.values}"
13151315
)
1316+
1317+
1318+
def test_constant_only_expression_mul_dataarray(m: Model) -> None:
1319+
const_arr = xr.DataArray([2, 3], dims=["dim_0"])
1320+
const_expr = LinearExpression(const_arr, m)
1321+
assert const_expr.is_constant
1322+
assert const_expr.nterm == 0
1323+
1324+
data_arr = xr.DataArray([10, 20], dims=["dim_0"])
1325+
expected_const = const_arr * data_arr
1326+
1327+
result = const_expr * data_arr
1328+
assert isinstance(result, LinearExpression)
1329+
assert result.is_constant
1330+
assert (result.const == expected_const).all()
1331+
1332+
result_rev = data_arr * const_expr
1333+
assert isinstance(result_rev, LinearExpression)
1334+
assert result_rev.is_constant
1335+
assert (result_rev.const == expected_const).all()
1336+
1337+
1338+
def test_constant_only_expression_mul_linexpr_with_vars(m: Model, x: Variable) -> None:
1339+
const_arr = xr.DataArray([2, 3], dims=["dim_0"])
1340+
const_expr = LinearExpression(const_arr, m)
1341+
assert const_expr.is_constant
1342+
assert const_expr.nterm == 0
1343+
1344+
expr_with_vars = 1 * x + 5
1345+
expected_coeffs = const_arr
1346+
expected_const = const_arr * 5
1347+
1348+
result = const_expr * expr_with_vars
1349+
assert isinstance(result, LinearExpression)
1350+
assert (result.coeffs == expected_coeffs).all()
1351+
assert (result.const == expected_const).all()
1352+
1353+
result_rev = expr_with_vars * const_expr
1354+
assert isinstance(result_rev, LinearExpression)
1355+
assert (result_rev.coeffs == expected_coeffs).all()
1356+
assert (result_rev.const == expected_const).all()
1357+
1358+
1359+
def test_constant_only_expression_mul_constant_only(m: Model) -> None:
1360+
const_arr = xr.DataArray([2, 3], dims=["dim_0"])
1361+
const_arr2 = xr.DataArray([4, 5], dims=["dim_0"])
1362+
const_expr = LinearExpression(const_arr, m)
1363+
const_expr2 = LinearExpression(const_arr2, m)
1364+
assert const_expr.is_constant
1365+
assert const_expr2.is_constant
1366+
1367+
expected_const = const_arr * const_arr2
1368+
1369+
result = const_expr * const_expr2
1370+
assert isinstance(result, LinearExpression)
1371+
assert result.is_constant
1372+
assert (result.const == expected_const).all()
1373+
1374+
result_rev = const_expr2 * const_expr
1375+
assert isinstance(result_rev, LinearExpression)
1376+
assert result_rev.is_constant
1377+
assert (result_rev.const == expected_const).all()
1378+
1379+
1380+
def test_constant_only_expression_mul_linexpr_with_vars_and_const(
1381+
m: Model, x: Variable
1382+
) -> None:
1383+
const_arr = xr.DataArray([2, 3], dims=["dim_0"])
1384+
const_expr = LinearExpression(const_arr, m)
1385+
assert const_expr.is_constant
1386+
1387+
expr_with_vars_and_const = 4 * x + 10
1388+
expected_coeffs = const_arr * 4
1389+
expected_const = const_arr * 10
1390+
1391+
result = const_expr * expr_with_vars_and_const
1392+
assert isinstance(result, LinearExpression)
1393+
assert not result.is_constant
1394+
assert (result.coeffs == expected_coeffs).all()
1395+
assert (result.const == expected_const).all()
1396+
1397+
result_rev = expr_with_vars_and_const * const_expr
1398+
assert isinstance(result_rev, LinearExpression)
1399+
assert not result_rev.is_constant
1400+
assert (result_rev.coeffs == expected_coeffs).all()
1401+
assert (result_rev.const == expected_const).all()

0 commit comments

Comments
 (0)