@@ -2260,22 +2260,22 @@ def _get_unique_root(
22602260 def _collect_heaviside_roots (
22612261 self ,
22622262 args : Sequence [sp .Basic ],
2263- ) -> list [sp .Expr ]:
2263+ ) -> list [tuple [ sp .Expr , sp . Expr ] ]:
22642264 """
2265- Recursively checks an expression for the occurrence of Heaviside
2266- functions and return all roots found
2265+ Recursively check an expression for the occurrence of Heaviside
2266+ functions and return all roots found.
22672267
22682268 :param args:
22692269 args attribute of the expanded expression
22702270
22712271 :returns:
2272- root functions that were extracted from Heaviside function
2273- arguments
2272+ List of ( root function, Heaviside x0)-tuples that were extracted
2273+ from Heaviside function arguments.
22742274 """
22752275 root_funs = []
22762276 for arg in args :
22772277 if arg .func == sp .Heaviside :
2278- root_funs .append (arg .args [ 0 ] )
2278+ root_funs .append (arg .args )
22792279 elif arg .has (sp .Heaviside ):
22802280 root_funs .extend (self ._collect_heaviside_roots (arg .args ))
22812281
@@ -2294,7 +2294,9 @@ def _collect_heaviside_roots(
22942294 )
22952295 )
22962296 )
2297- root_funs = [r .subs (w_sorted ) for r in root_funs ]
2297+ root_funs = [
2298+ (r [0 ].subs (w_sorted ), r [1 ].subs (w_sorted )) for r in root_funs
2299+ ]
22982300
22992301 return root_funs
23002302
@@ -2322,15 +2324,17 @@ def _process_heavisides(
23222324 heavisides = []
23232325 # run through the expression tree and get the roots
23242326 tmp_roots_old = self ._collect_heaviside_roots ((dxdt ,))
2325- for tmp_old in unique_preserve_order (tmp_roots_old ):
2327+ for tmp_root_old , tmp_x0_old in unique_preserve_order (tmp_roots_old ):
23262328 # we want unique identifiers for the roots
2327- tmp_new = self ._get_unique_root (tmp_old , roots )
2329+ tmp_root_new = self ._get_unique_root (tmp_root_old , roots )
23282330 # `tmp_new` is None if the root is not time-dependent.
2329- if tmp_new is None :
2331+ if tmp_root_new is None :
23302332 continue
23312333 # For Heavisides, we need to add the negative function as well
2332- self ._get_unique_root (sp .sympify (- tmp_old ), roots )
2333- heavisides .append ((sp .Heaviside (tmp_old ), tmp_new ))
2334+ self ._get_unique_root (sp .sympify (- tmp_root_old ), roots )
2335+ heavisides .append (
2336+ (sp .Heaviside (tmp_root_old , tmp_x0_old ), tmp_root_new )
2337+ )
23342338
23352339 if heavisides :
23362340 # only apply subs if necessary
0 commit comments