Skip to content

Commit 188b466

Browse files
committed
handle second Heaviside arg in replacement
1 parent 7ae7230 commit 188b466

1 file changed

Lines changed: 16 additions & 12 deletions

File tree

python/sdist/amici/de_model.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2260,22 +2260,22 @@ def _get_unique_root(
22602260
def _collect_heaviside_roots(
22612261
self,
22622262
args: Sequence[sp.Basic],
2263-
) -> list[sp.Expr]:
2263+
) -> list[tuple[sp.Expr, sp.Expr]]:
22642264
"""
2265-
Recursively checks an expression for the occurrence of Heaviside
2266-
functions and return all roots found
2265+
Recursively check an expression for the occurrence of Heaviside
2266+
functions and return all roots found.
22672267
22682268
:param args:
22692269
args attribute of the expanded expression
22702270
22712271
:returns:
2272-
root functions that were extracted from Heaviside function
2273-
arguments
2272+
List of (root function, Heaviside x0)-tuples that were extracted
2273+
from Heaviside function arguments.
22742274
"""
22752275
root_funs = []
22762276
for arg in args:
22772277
if arg.func == sp.Heaviside:
2278-
root_funs.append(arg.args[0])
2278+
root_funs.append(arg.args)
22792279
elif arg.has(sp.Heaviside):
22802280
root_funs.extend(self._collect_heaviside_roots(arg.args))
22812281

@@ -2294,7 +2294,9 @@ def _collect_heaviside_roots(
22942294
)
22952295
)
22962296
)
2297-
root_funs = [r.subs(w_sorted) for r in root_funs]
2297+
root_funs = [
2298+
(r[0].subs(w_sorted), r[1].subs(w_sorted)) for r in root_funs
2299+
]
22982300

22992301
return root_funs
23002302

@@ -2322,15 +2324,17 @@ def _process_heavisides(
23222324
heavisides = []
23232325
# run through the expression tree and get the roots
23242326
tmp_roots_old = self._collect_heaviside_roots((dxdt,))
2325-
for tmp_old in unique_preserve_order(tmp_roots_old):
2327+
for tmp_root_old, tmp_x0_old in unique_preserve_order(tmp_roots_old):
23262328
# we want unique identifiers for the roots
2327-
tmp_new = self._get_unique_root(tmp_old, roots)
2329+
tmp_root_new = self._get_unique_root(tmp_root_old, roots)
23282330
# `tmp_new` is None if the root is not time-dependent.
2329-
if tmp_new is None:
2331+
if tmp_root_new is None:
23302332
continue
23312333
# For Heavisides, we need to add the negative function as well
2332-
self._get_unique_root(sp.sympify(-tmp_old), roots)
2333-
heavisides.append((sp.Heaviside(tmp_old), tmp_new))
2334+
self._get_unique_root(sp.sympify(-tmp_root_old), roots)
2335+
heavisides.append(
2336+
(sp.Heaviside(tmp_root_old, tmp_x0_old), tmp_root_new)
2337+
)
23342338

23352339
if heavisides:
23362340
# only apply subs if necessary

0 commit comments

Comments
 (0)