diff --git a/python/sdist/amici/de_model.py b/python/sdist/amici/de_model.py index 46270d0a50..69dcf5aedf 100644 --- a/python/sdist/amici/de_model.py +++ b/python/sdist/amici/de_model.py @@ -2260,22 +2260,22 @@ def _get_unique_root( def _collect_heaviside_roots( self, args: Sequence[sp.Basic], - ) -> list[sp.Expr]: + ) -> list[tuple[sp.Expr, sp.Expr]]: """ - Recursively checks an expression for the occurrence of Heaviside - functions and return all roots found + Recursively check an expression for the occurrence of Heaviside + functions and return all roots found. :param args: args attribute of the expanded expression :returns: - root functions that were extracted from Heaviside function - arguments + List of (root function, Heaviside x0)-tuples that were extracted + from Heaviside function arguments. """ root_funs = [] for arg in args: if arg.func == sp.Heaviside: - root_funs.append(arg.args[0]) + root_funs.append(arg.args) elif arg.has(sp.Heaviside): root_funs.extend(self._collect_heaviside_roots(arg.args)) @@ -2294,7 +2294,9 @@ def _collect_heaviside_roots( ) ) ) - root_funs = [r.subs(w_sorted) for r in root_funs] + root_funs = [ + (r[0].subs(w_sorted), r[1].subs(w_sorted)) for r in root_funs + ] return root_funs @@ -2322,15 +2324,17 @@ def _process_heavisides( heavisides = [] # run through the expression tree and get the roots tmp_roots_old = self._collect_heaviside_roots((dxdt,)) - for tmp_old in unique_preserve_order(tmp_roots_old): + for tmp_root_old, tmp_x0_old in unique_preserve_order(tmp_roots_old): # we want unique identifiers for the roots - tmp_new = self._get_unique_root(tmp_old, roots) + tmp_root_new = self._get_unique_root(tmp_root_old, roots) # `tmp_new` is None if the root is not time-dependent. - if tmp_new is None: + if tmp_root_new is None: continue # For Heavisides, we need to add the negative function as well - self._get_unique_root(sp.sympify(-tmp_old), roots) - heavisides.append((sp.Heaviside(tmp_old), tmp_new)) + self._get_unique_root(sp.sympify(-tmp_root_old), roots) + heavisides.append( + (sp.Heaviside(tmp_root_old, tmp_x0_old), tmp_root_new) + ) if heavisides: # only apply subs if necessary diff --git a/python/sdist/amici/import_utils.py b/python/sdist/amici/import_utils.py index 26c05feee0..2a99e310bb 100644 --- a/python/sdist/amici/import_utils.py +++ b/python/sdist/amici/import_utils.py @@ -520,16 +520,16 @@ def _parse_heaviside_trigger(trigger: sp.Expr) -> sp.Expr: # step with H(0) = 1 if isinstance(trigger, sp.core.relational.StrictLessThan): # x < y => x - y < 0 => r < 0 - return sp.Integer(1) - sp.Heaviside(root) + return sp.Integer(1) - sp.Heaviside(root, 1) if isinstance(trigger, sp.core.relational.LessThan): # x <= y => not(y < x) => not(y - x < 0) => not -r < 0 - return sp.Heaviside(-root) + return sp.Heaviside(-root, 1) if isinstance(trigger, sp.core.relational.StrictGreaterThan): # y > x => y - x < 0 => -r < 0 - return sp.Integer(1) - sp.Heaviside(-root) + return sp.Integer(1) - sp.Heaviside(-root, 1) if isinstance(trigger, sp.core.relational.GreaterThan): # y >= x => not(x < y) => not(x - y < 0) => not r < 0 - return sp.Heaviside(root) + return sp.Heaviside(root, 1) # rewrite n-ary XOR to OR to be handled below: trigger = trigger.replace(sp.Xor, _xor_to_or) diff --git a/python/tests/test_sbml_import.py b/python/tests/test_sbml_import.py index 342b6146b2..f5d20f6c24 100644 --- a/python/tests/test_sbml_import.py +++ b/python/tests/test_sbml_import.py @@ -15,6 +15,7 @@ from amici.testing import TemporaryDirectoryWinSafe as TemporaryDirectory from amici.testing import skip_on_valgrind from numpy.testing import assert_allclose, assert_array_equal +from amici import import_model_module from conftest import MODEL_STEADYSTATE_SCALED_XML @@ -764,7 +765,6 @@ def test_import_same_model_name(): """Test for error when loading a model with the same extension name as an already loaded model.""" from amici.antimony_import import antimony2amici - from amici import import_model_module # create three versions of a toy model with different parameter values # to detect which model was loaded @@ -871,3 +871,33 @@ def test_regression_2642(): len(np.unique(r.w[:, model.getExpressionIds().index("binding")])) == 1 ) + + +@skip_on_valgrind +def test_regression_2700(): + """Check comparison operators.""" + from amici.antimony_import import antimony2amici + + model_name = "regression_2700" + with TemporaryDirectory(prefix=model_name) as outdir: + antimony2amici( + """ + a = 1 + # condition is always true, so `pp` should be 1 + pp := piecewise(1, a >= 1 && a <= 1, 0) + """, + model_name=model_name, + output_dir=outdir, + ) + + model_module = import_model_module(model_name, outdir) + + model = model_module.get_model() + + model.setTimepoints([0, 1, 2]) + + solver = model.getSolver() + + rdata = amici.runAmiciSimulation(model, solver) + + assert np.all(rdata.by_id("pp") == [1, 1, 1]) diff --git a/src/model.cpp b/src/model.cpp index 9edb149483..faf57f13d8 100644 --- a/src/model.cpp +++ b/src/model.cpp @@ -2875,6 +2875,8 @@ void Model::fdwdp(realtype const t, realtype const* x, bool include_static) { } void Model::fdwdx(realtype const t, realtype const* x, bool include_static) { + // NOTE: (at least) Model_{ODE,DAE}::fJSparse rely on `fw` and `fdwdw` + // being called from here. They need to be executed even if nx_solver==0. if (!nw) return; @@ -2882,18 +2884,19 @@ void Model::fdwdx(realtype const t, realtype const* x, bool include_static) { derived_state_.dwdx_.zero(); if (pythonGenerated) { - if (!derived_state_.dwdx_hierarchical_.at(0).capacity()) - return; - fdwdw(t, x, include_static); + auto&& dwdx_hierarchical_0 = derived_state_.dwdx_hierarchical_.at(0); + if (!dwdx_hierarchical_0.data() || !dwdx_hierarchical_0.capacity()) + return; + if (include_static) { - derived_state_.dwdx_hierarchical_.at(0).zero(); - fdwdx_colptrs(derived_state_.dwdx_hierarchical_.at(0)); - fdwdx_rowvals(derived_state_.dwdx_hierarchical_.at(0)); + dwdx_hierarchical_0.zero(); + fdwdx_colptrs(dwdx_hierarchical_0); + fdwdx_rowvals(dwdx_hierarchical_0); } fdwdx( - derived_state_.dwdx_hierarchical_.at(0).data(), t, x, + dwdx_hierarchical_0.data(), t, x, state_.unscaledParameters.data(), state_.fixedParameters.data(), state_.h.data(), derived_state_.w_.data(), state_.total_cl.data(), derived_state_.spl_.data(), include_static