Skip to content

Commit d0a1f20

Browse files
authored
Fix Heaviside functions for <, >, >= and <= (#2701)
* Fix Heaviside functions for <, >, >= and <= The value at 0 for the Heaviside function was incorrectly set to 1/2, leading to incorrect simulation results. Fixes #2700. * handle second Heaviside arg in replacement * Fix crash for models without state variables Previously, `nx_solver=0` would have in a crash when calling `amici::SUNMatrixWrapper::capacity` on `SUNMatrixWrapper` of 0-sized sparse matrix. Unrelated: Replace some unnecessary `std::vector::at` calls. * add test
1 parent ce1b673 commit d0a1f20

3 files changed

Lines changed: 51 additions & 17 deletions

File tree

python/sdist/amici/de_model.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2267,22 +2267,22 @@ def _get_unique_root(
22672267
def _collect_heaviside_roots(
22682268
self,
22692269
args: Sequence[sp.Basic],
2270-
) -> list[sp.Expr]:
2270+
) -> list[tuple[sp.Expr, sp.Expr]]:
22712271
"""
2272-
Recursively checks an expression for the occurrence of Heaviside
2273-
functions and return all roots found
2272+
Recursively check an expression for the occurrence of Heaviside
2273+
functions and return all roots found.
22742274
22752275
:param args:
22762276
args attribute of the expanded expression
22772277
22782278
:returns:
2279-
root functions that were extracted from Heaviside function
2280-
arguments
2279+
List of (root function, Heaviside x0)-tuples that were extracted
2280+
from Heaviside function arguments.
22812281
"""
22822282
root_funs = []
22832283
for arg in args:
22842284
if arg.func == sp.Heaviside:
2285-
root_funs.append(arg.args[0])
2285+
root_funs.append(arg.args)
22862286
elif arg.has(sp.Heaviside):
22872287
root_funs.extend(self._collect_heaviside_roots(arg.args))
22882288

@@ -2301,7 +2301,9 @@ def _collect_heaviside_roots(
23012301
)
23022302
)
23032303
)
2304-
root_funs = [r.subs(w_sorted) for r in root_funs]
2304+
root_funs = [
2305+
(r[0].subs(w_sorted), r[1].subs(w_sorted)) for r in root_funs
2306+
]
23052307

23062308
return root_funs
23072309

@@ -2329,15 +2331,17 @@ def _process_heavisides(
23292331
heavisides = []
23302332
# run through the expression tree and get the roots
23312333
tmp_roots_old = self._collect_heaviside_roots((dxdt,))
2332-
for tmp_old in unique_preserve_order(tmp_roots_old):
2334+
for tmp_root_old, tmp_x0_old in unique_preserve_order(tmp_roots_old):
23332335
# we want unique identifiers for the roots
2334-
tmp_new = self._get_unique_root(tmp_old, roots)
2336+
tmp_root_new = self._get_unique_root(tmp_root_old, roots)
23352337
# `tmp_new` is None if the root is not time-dependent.
2336-
if tmp_new is None:
2338+
if tmp_root_new is None:
23372339
continue
23382340
# For Heavisides, we need to add the negative function as well
2339-
self._get_unique_root(sp.sympify(-tmp_old), roots)
2340-
heavisides.append((sp.Heaviside(tmp_old), tmp_new))
2341+
self._get_unique_root(sp.sympify(-tmp_root_old), roots)
2342+
heavisides.append(
2343+
(sp.Heaviside(tmp_root_old, tmp_x0_old), tmp_root_new)
2344+
)
23412345

23422346
if heavisides:
23432347
# only apply subs if necessary

python/sdist/amici/import_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -520,16 +520,16 @@ def _parse_heaviside_trigger(trigger: sp.Expr) -> sp.Expr:
520520
# step with H(0) = 1
521521
if isinstance(trigger, sp.core.relational.StrictLessThan):
522522
# x < y => x - y < 0 => r < 0
523-
return sp.Integer(1) - sp.Heaviside(root)
523+
return sp.Integer(1) - sp.Heaviside(root, 1)
524524
if isinstance(trigger, sp.core.relational.LessThan):
525525
# x <= y => not(y < x) => not(y - x < 0) => not -r < 0
526-
return sp.Heaviside(-root)
526+
return sp.Heaviside(-root, 1)
527527
if isinstance(trigger, sp.core.relational.StrictGreaterThan):
528528
# y > x => y - x < 0 => -r < 0
529-
return sp.Integer(1) - sp.Heaviside(-root)
529+
return sp.Integer(1) - sp.Heaviside(-root, 1)
530530
if isinstance(trigger, sp.core.relational.GreaterThan):
531531
# y >= x => not(x < y) => not(x - y < 0) => not r < 0
532-
return sp.Heaviside(root)
532+
return sp.Heaviside(root, 1)
533533

534534
# rewrite n-ary XOR to OR to be handled below:
535535
trigger = trigger.replace(sp.Xor, _xor_to_or)

python/tests/test_sbml_import.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from amici.testing import TemporaryDirectoryWinSafe as TemporaryDirectory
1616
from amici.testing import skip_on_valgrind
1717
from numpy.testing import assert_allclose, assert_array_equal
18+
from amici import import_model_module
1819

1920
from conftest import MODEL_STEADYSTATE_SCALED_XML
2021

@@ -764,7 +765,6 @@ def test_import_same_model_name():
764765
"""Test for error when loading a model with the same extension name as an
765766
already loaded model."""
766767
from amici.antimony_import import antimony2amici
767-
from amici import import_model_module
768768

769769
# create three versions of a toy model with different parameter values
770770
# to detect which model was loaded
@@ -871,3 +871,33 @@ def test_regression_2642():
871871
len(np.unique(r.w[:, model.getExpressionIds().index("binding")]))
872872
== 1
873873
)
874+
875+
876+
@skip_on_valgrind
877+
def test_regression_2700():
878+
"""Check comparison operators."""
879+
from amici.antimony_import import antimony2amici
880+
881+
model_name = "regression_2700"
882+
with TemporaryDirectory(prefix=model_name) as outdir:
883+
antimony2amici(
884+
"""
885+
a = 1
886+
# condition is always true, so `pp` should be 1
887+
pp := piecewise(1, a >= 1 && a <= 1, 0)
888+
""",
889+
model_name=model_name,
890+
output_dir=outdir,
891+
)
892+
893+
model_module = import_model_module(model_name, outdir)
894+
895+
model = model_module.get_model()
896+
897+
model.setTimepoints([0, 1, 2])
898+
899+
solver = model.getSolver()
900+
901+
rdata = amici.runAmiciSimulation(model, solver)
902+
903+
assert np.all(rdata.by_id("pp") == [1, 1, 1])

0 commit comments

Comments
 (0)