Skip to content

Commit ce1b673

Browse files
authored
SBML import: avoid repeated xdot==0 checks (#2706)
* SBML import: avoid repeated xdot==0 checks Avoid checking for `xdot==0` for every single event when computing `deltasx`. Avoid avoid some extra matrix multiplication, which may result in undefined symbols for some funny models where xdot==0. * trigger SBML test suite on DEModel changes
1 parent 0d689bd commit ce1b673

2 files changed

Lines changed: 13 additions & 5 deletions

File tree

.github/workflows/test_sbml_semantic_test_suite.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ on:
1010
- .github/workflows/test_sbml_semantic_test_suite.yml
1111
- python/sdist/amici/de_export.py
1212
- python/sdist/amici/de_model_components.py
13+
- python/sdist/amici/de_model.py
1314
- python/sdist/amici/sbml_import.py
1415
- python/sdist/amici/import_utils.py
1516
- scripts/run-SBMLTestsuite.sh

python/sdist/amici/de_model.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1693,19 +1693,26 @@ def _compute_equation(self, name: str) -> None:
16931693
]
16941694

16951695
elif name == "deltasx":
1696-
if self.num_states_solver() * self.num_par() == 0:
1696+
if (
1697+
self.num_states_solver() * self.num_par() * self.num_events()
1698+
== 0
1699+
):
16971700
self._eqs[name] = []
16981701
return
16991702

1703+
xdot_is_zero = smart_is_zero_matrix(self.eq("xdot"))
1704+
17001705
event_eqs = []
17011706
for ie, event in enumerate(self._events):
17021707
tmp_eq = sp.zeros(self.num_states_solver(), self.num_par())
17031708

17041709
# need to check if equations are zero since we are using
17051710
# symbols
1706-
if not smart_is_zero_matrix(
1707-
self.eq("stau")[ie]
1708-
) and not smart_is_zero_matrix(self.eq("xdot")):
1711+
1712+
if (
1713+
not smart_is_zero_matrix(self.eq("stau")[ie])
1714+
and not xdot_is_zero
1715+
):
17091716
tmp_eq += smart_multiply(
17101717
self.sym("xdot") - self.sym("xdot_old"),
17111718
self.sym("stau").T,
@@ -1739,7 +1746,7 @@ def _compute_equation(self, name: str) -> None:
17391746
self.eq("ddeltaxdx")[ie], tmp_dxdp
17401747
)
17411748

1742-
else:
1749+
elif not xdot_is_zero:
17431750
tmp_eq = smart_multiply(
17441751
self.sym("xdot") - self.sym("xdot_old"),
17451752
self.eq("stau")[ie],

0 commit comments

Comments
 (0)