Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 16 additions & 12 deletions python/sdist/amici/de_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2260,22 +2260,22 @@ def _get_unique_root(
def _collect_heaviside_roots(
self,
args: Sequence[sp.Basic],
) -> list[sp.Expr]:
) -> list[tuple[sp.Expr, sp.Expr]]:
"""
Recursively checks an expression for the occurrence of Heaviside
functions and return all roots found
Recursively check an expression for the occurrence of Heaviside
functions and return all roots found.

:param args:
args attribute of the expanded expression

:returns:
root functions that were extracted from Heaviside function
arguments
List of (root function, Heaviside x0)-tuples that were extracted
from Heaviside function arguments.
"""
root_funs = []
for arg in args:
if arg.func == sp.Heaviside:
root_funs.append(arg.args[0])
root_funs.append(arg.args)
elif arg.has(sp.Heaviside):
root_funs.extend(self._collect_heaviside_roots(arg.args))

Expand All @@ -2294,7 +2294,9 @@ def _collect_heaviside_roots(
)
)
)
root_funs = [r.subs(w_sorted) for r in root_funs]
root_funs = [
(r[0].subs(w_sorted), r[1].subs(w_sorted)) for r in root_funs
]

return root_funs

Expand Down Expand Up @@ -2322,15 +2324,17 @@ def _process_heavisides(
heavisides = []
# run through the expression tree and get the roots
tmp_roots_old = self._collect_heaviside_roots((dxdt,))
for tmp_old in unique_preserve_order(tmp_roots_old):
for tmp_root_old, tmp_x0_old in unique_preserve_order(tmp_roots_old):
# we want unique identifiers for the roots
tmp_new = self._get_unique_root(tmp_old, roots)
tmp_root_new = self._get_unique_root(tmp_root_old, roots)
# `tmp_new` is None if the root is not time-dependent.
if tmp_new is None:
if tmp_root_new is None:
continue
# For Heavisides, we need to add the negative function as well
self._get_unique_root(sp.sympify(-tmp_old), roots)
heavisides.append((sp.Heaviside(tmp_old), tmp_new))
self._get_unique_root(sp.sympify(-tmp_root_old), roots)
heavisides.append(
(sp.Heaviside(tmp_root_old, tmp_x0_old), tmp_root_new)
)

if heavisides:
# only apply subs if necessary
Expand Down
8 changes: 4 additions & 4 deletions python/sdist/amici/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,16 +520,16 @@
# step with H(0) = 1
if isinstance(trigger, sp.core.relational.StrictLessThan):
# x < y => x - y < 0 => r < 0
return sp.Integer(1) - sp.Heaviside(root)
return sp.Integer(1) - sp.Heaviside(root, 1)
if isinstance(trigger, sp.core.relational.LessThan):
# x <= y => not(y < x) => not(y - x < 0) => not -r < 0
return sp.Heaviside(-root)
return sp.Heaviside(-root, 1)

Check warning on line 526 in python/sdist/amici/import_utils.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/import_utils.py#L526

Added line #L526 was not covered by tests
if isinstance(trigger, sp.core.relational.StrictGreaterThan):
# y > x => y - x < 0 => -r < 0
return sp.Integer(1) - sp.Heaviside(-root)
return sp.Integer(1) - sp.Heaviside(-root, 1)

Check warning on line 529 in python/sdist/amici/import_utils.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/import_utils.py#L529

Added line #L529 was not covered by tests
if isinstance(trigger, sp.core.relational.GreaterThan):
# y >= x => not(x < y) => not(x - y < 0) => not r < 0
return sp.Heaviside(root)
return sp.Heaviside(root, 1)

Check warning on line 532 in python/sdist/amici/import_utils.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/import_utils.py#L532

Added line #L532 was not covered by tests

# rewrite n-ary XOR to OR to be handled below:
trigger = trigger.replace(sp.Xor, _xor_to_or)
Expand Down
32 changes: 31 additions & 1 deletion python/tests/test_sbml_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from amici.testing import TemporaryDirectoryWinSafe as TemporaryDirectory
from amici.testing import skip_on_valgrind
from numpy.testing import assert_allclose, assert_array_equal
from amici import import_model_module

from conftest import MODEL_STEADYSTATE_SCALED_XML

Expand Down Expand Up @@ -764,7 +765,6 @@ def test_import_same_model_name():
"""Test for error when loading a model with the same extension name as an
already loaded model."""
from amici.antimony_import import antimony2amici
from amici import import_model_module

# create three versions of a toy model with different parameter values
# to detect which model was loaded
Expand Down Expand Up @@ -871,3 +871,33 @@ def test_regression_2642():
len(np.unique(r.w[:, model.getExpressionIds().index("binding")]))
== 1
)


@skip_on_valgrind
def test_regression_2700():
"""Check comparison operators."""
from amici.antimony_import import antimony2amici

model_name = "regression_2700"
with TemporaryDirectory(prefix=model_name) as outdir:
antimony2amici(
"""
a = 1
# condition is always true, so `pp` should be 1
pp := piecewise(1, a >= 1 && a <= 1, 0)
""",
model_name=model_name,
output_dir=outdir,
)

model_module = import_model_module(model_name, outdir)

model = model_module.get_model()

model.setTimepoints([0, 1, 2])

solver = model.getSolver()

rdata = amici.runAmiciSimulation(model, solver)

assert np.all(rdata.by_id("pp") == [1, 1, 1])
17 changes: 10 additions & 7 deletions src/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2875,25 +2875,28 @@ void Model::fdwdp(realtype const t, realtype const* x, bool include_static) {
}

void Model::fdwdx(realtype const t, realtype const* x, bool include_static) {
// NOTE: (at least) Model_{ODE,DAE}::fJSparse rely on `fw` and `fdwdw`
// being called from here. They need to be executed even if nx_solver==0.
if (!nw)
return;

fw(t, x, include_static);

derived_state_.dwdx_.zero();
if (pythonGenerated) {
if (!derived_state_.dwdx_hierarchical_.at(0).capacity())
return;

fdwdw(t, x, include_static);

auto&& dwdx_hierarchical_0 = derived_state_.dwdx_hierarchical_.at(0);
if (!dwdx_hierarchical_0.data() || !dwdx_hierarchical_0.capacity())
return;

if (include_static) {
derived_state_.dwdx_hierarchical_.at(0).zero();
fdwdx_colptrs(derived_state_.dwdx_hierarchical_.at(0));
fdwdx_rowvals(derived_state_.dwdx_hierarchical_.at(0));
dwdx_hierarchical_0.zero();
fdwdx_colptrs(dwdx_hierarchical_0);
fdwdx_rowvals(dwdx_hierarchical_0);
}
fdwdx(
derived_state_.dwdx_hierarchical_.at(0).data(), t, x,
dwdx_hierarchical_0.data(), t, x,
state_.unscaledParameters.data(), state_.fixedParameters.data(),
state_.h.data(), derived_state_.w_.data(), state_.total_cl.data(),
derived_state_.spl_.data(), include_static
Expand Down
Loading