Skip to content

Commit c84db32

Browse files
authored
Replace min/max disguised as piecewise (#3053)
During model import, simplify piecewise functions that can be expressed as min()/max(). So far, all piecewise constructs are treated as discontinuous (see also #2049). For min/max, only the derivative is discontinuous, which currently does not receive any special treatment in amici (see also https://amici.readthedocs.io/en/latest/implementation_discontinuities.html). So this change prevents root tracking for some unnecessary piecewises. Addresses some aspects of #2049.
1 parent 1293e9d commit c84db32

4 files changed

Lines changed: 90 additions & 1 deletion

File tree

python/sdist/amici/importers/sbml/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
from amici.logging import get_logger, log_execution_time, set_log_level
6767
from amici.sympy_utils import (
6868
_monkeypatch_sympy,
69+
_piecewise_to_minmax,
6970
smart_is_zero_matrix,
7071
smart_multiply,
7172
)
@@ -2882,6 +2883,7 @@ def subs_locals(expr: sp.Basic) -> sp.Basic:
28822883
# piecewise to heavisides
28832884
if piecewise_to_heaviside:
28842885
try:
2886+
expr = expr.replace(sp.Piecewise, _piecewise_to_minmax)
28852887
expr = expr.replace(
28862888
sp.Piecewise,
28872889
lambda *args: _parse_piecewise_to_heaviside(args),

python/sdist/amici/sympy_utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,3 +216,19 @@ def _parallel_applyfunc(obj: sp.Matrix, func: Callable) -> sp.Matrix:
216216
"to a module-level function or disable parallelization by "
217217
"setting `AMICI_IMPORT_NPROCS=1`."
218218
) from e
219+
220+
221+
def _piecewise_to_minmax(
222+
*expr_cond_pairs: tuple[tuple[sp.Basic, sp.Basic], ...],
223+
) -> sp.Basic:
224+
"""Replace min/max defined via Piecewise with plain Min/Max.
225+
226+
To be used in ``expr = expr.replace(sp.Piecewise, pw_to_minmax)``.
227+
"""
228+
if len(expr_cond_pairs) == 2 and expr_cond_pairs[-1][1] == sp.true:
229+
(expr1, cond1), (expr2, cond2) = expr_cond_pairs
230+
if cond1.args == (expr1, expr2) and cond1.func in (sp.Lt, sp.Le):
231+
return sp.Min(expr1, expr2)
232+
elif cond1.args == (expr1, expr2) and cond1.func in (sp.Gt, sp.Ge):
233+
return sp.Max(expr1, expr2)
234+
return sp.Piecewise(*expr_cond_pairs)

python/tests/test_sbml_import.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import sympy as sp
1414
from amici import import_model_module
1515
from amici.gradient_check import check_derivatives
16+
from amici.importers.antimony import antimony2sbml
1617
from amici.importers.sbml import SbmlImporter, SymbolId
1718
from amici.importers.utils import (
1819
MeasurementChannel as MC,
@@ -1196,3 +1197,21 @@ def test_time_dependent_initial_assignment(compute_conservation_laws: bool):
11961197
symbol_with_assumptions("p0"),
11971198
amici_time_symbol * 1.0 + 3.0,
11981199
]
1200+
1201+
1202+
@skip_on_valgrind
1203+
def test_minmax_piecewise_is_converted_to_minmax():
1204+
"""Test that _piecewise_to_minmax is applied during SBML import."""
1205+
sbml_str = antimony2sbml("""
1206+
x' = piecewise(a, a > b, b)
1207+
y' = piecewise(a, a < b, b)
1208+
""")
1209+
sbml_importer = SbmlImporter(sbml_source=sbml_str, from_file=False)
1210+
de_model = sbml_importer._build_ode_model()
1211+
# no events should be created for min/max
1212+
assert not de_model.events()
1213+
assert len(de_model.sym("h")) == 0
1214+
# min/max are present in the equations
1215+
xdot = de_model.eq("xdot")
1216+
assert xdot.has(sp.Min)
1217+
assert xdot.has(sp.Max)

python/tests/test_sympy_utils.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
"""Tests related to the sympy_utils module."""
22

33
import sympy as sp
4-
from amici.sympy_utils import _custom_pow_eval_derivative, _monkeypatched
4+
from amici.sympy_utils import (
5+
_custom_pow_eval_derivative,
6+
_monkeypatched,
7+
_piecewise_to_minmax,
8+
)
59
from amici.testing import skip_on_valgrind
610

711

@@ -22,3 +26,51 @@ def test_monkeypatch():
2226

2327
# check that the monkeypatch is transient
2428
assert (t**n).diff(t).subs(vals) is sp.nan
29+
30+
31+
@skip_on_valgrind
32+
def test_rewrite_piecewise_minmax():
33+
"""Test rewriting of piecewise min/max to sympy Min/Max functions."""
34+
x, y, z = sp.symbols("x y z")
35+
36+
assert sp.Piecewise((x, x < y), (y, True)).replace(
37+
sp.Piecewise, _piecewise_to_minmax
38+
) == sp.Min(x, y)
39+
assert sp.Piecewise((x, x <= y), (y, True)).replace(
40+
sp.Piecewise, _piecewise_to_minmax
41+
) == sp.Min(x, y)
42+
assert sp.Piecewise((x, x > y), (y, True)).replace(
43+
sp.Piecewise, _piecewise_to_minmax
44+
) == sp.Max(x, y)
45+
assert sp.Piecewise((x, x >= y), (y, True)).replace(
46+
sp.Piecewise, _piecewise_to_minmax
47+
) == sp.Max(x, y)
48+
assert sp.Piecewise((x, y > x), (y, True)).replace(
49+
sp.Piecewise, _piecewise_to_minmax
50+
) == sp.Min(x, y)
51+
assert sp.Piecewise((x, y >= x), (y, True)).replace(
52+
sp.Piecewise, _piecewise_to_minmax
53+
) == sp.Min(x, y)
54+
assert sp.Piecewise((x, y < x), (y, True)).replace(
55+
sp.Piecewise, _piecewise_to_minmax
56+
) == sp.Max(x, y)
57+
assert sp.Piecewise((x, y <= x), (y, True)).replace(
58+
sp.Piecewise, _piecewise_to_minmax
59+
) == sp.Max(x, y)
60+
61+
# can't replace
62+
assert sp.Piecewise((z, y <= x), (y, True)).replace(
63+
sp.Piecewise, _piecewise_to_minmax
64+
) == sp.Piecewise((z, y <= x), (y, True))
65+
66+
# replace recursively
67+
expr = sp.Piecewise(
68+
(sp.Piecewise((x, x < y), (y, True)), x < z),
69+
(sp.Piecewise((y, y < z), (z, True)), True),
70+
)
71+
replaced = expr.replace(sp.Piecewise, _piecewise_to_minmax)
72+
expected = sp.Piecewise(
73+
(sp.Min(x, y), x < z),
74+
(sp.Min(y, z), True),
75+
)
76+
assert replaced == expected

0 commit comments

Comments
 (0)