Skip to content

Commit 1ad4ee0

Browse files
authored
CI: Update libpetab, fix failing petab v2 tests (#3124)
Enable previously skipped tests. Skip gradient check for nx=0 models.
1 parent b2699a8 commit 1ad4ee0

2 files changed

Lines changed: 20 additions & 22 deletions

File tree

.github/workflows/test_petab_test_suite.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ jobs:
8787
run: |
8888
source ./venv/bin/activate \
8989
&& python3 -m pip uninstall -y petab \
90-
&& python3 -m pip install git+https://github.com/petab-dev/libpetab-python.git@8dc6c1c4b801fba5acc35fcd25308a659d01050e \
90+
&& python3 -m pip install git+https://github.com/petab-dev/libpetab-python.git@44c8062ce1b87a74a0ba1bd2551de0cdc2a13ff1 \
9191
&& python3 -m pip install git+https://github.com/pysb/pysb@master \
9292
&& python3 -m pip install sympy>=1.12.1
9393
@@ -186,7 +186,7 @@ jobs:
186186
run: |
187187
source ./venv/bin/activate \
188188
&& python3 -m pip uninstall -y petab \
189-
&& python3 -m pip install git+https://github.com/petab-dev/libpetab-python.git@d57d9fed8d8d5f8592e76d0b15676e05397c3b4b \
189+
&& python3 -m pip install git+https://github.com/petab-dev/libpetab-python.git@44c8062ce1b87a74a0ba1bd2551de0cdc2a13ff1 \
190190
&& python3 -m pip install git+https://github.com/pysb/pysb@master \
191191
&& python3 -m pip install sympy>=1.12.1
192192

tests/petab_test_suite/test_petab_v2_suite.py

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import sys
66

77
import diffrax
8+
import jax
89
import pandas as pd
910
import petabtests
1011
import pytest
@@ -21,7 +22,6 @@
2122
)
2223
from amici.sim.sundials.petab import PetabSimulator
2324
from petab import v2
24-
import jax
2525

2626
logger = get_logger(__name__, logging.DEBUG)
2727
set_log_level(get_logger("amici.petab_import"), logging.DEBUG)
@@ -70,7 +70,6 @@ def _test_case(case, model_type, version, jax):
7070
f"petab_{model_type}_test_case_{case}_{version.replace('.', '_')}"
7171
)
7272

73-
7473
if jax:
7574
from amici.jax import petab_simulate, run_simulations
7675
from amici.jax.petab import DEFAULT_CONTROLLER_SETTINGS
@@ -90,28 +89,25 @@ def _test_case(case, model_type, version, jax):
9089

9190
if case.startswith("0016"):
9291
controller = diffrax.PIDController(
93-
**DEFAULT_CONTROLLER_SETTINGS,
94-
dtmax=0.5
92+
**DEFAULT_CONTROLLER_SETTINGS, dtmax=0.5
9593
)
9694
else:
97-
controller = diffrax.PIDController(
98-
**DEFAULT_CONTROLLER_SETTINGS
99-
)
95+
controller = diffrax.PIDController(**DEFAULT_CONTROLLER_SETTINGS)
10096

10197
llh, _ = run_simulations(
102-
jax_problem,
103-
steady_state_event=steady_state_event,
98+
jax_problem,
99+
steady_state_event=steady_state_event,
104100
controller=controller,
105101
)
106102
chi2, _ = run_simulations(
107-
jax_problem,
108-
ret="chi2",
109-
steady_state_event=steady_state_event,
103+
jax_problem,
104+
ret="chi2",
105+
steady_state_event=steady_state_event,
110106
controller=controller,
111107
)
112108
simulation_df = petab_simulate(
113-
jax_problem,
114-
steady_state_event=steady_state_event,
109+
jax_problem,
110+
steady_state_event=steady_state_event,
115111
controller=controller,
116112
)
117113
else:
@@ -137,7 +133,9 @@ def _test_case(case, model_type, version, jax):
137133
)
138134
chi2 = sum(rdata.chi2 for rdata in rdatas)
139135
llh = res.llh
140-
simulation_df = rdatas_to_simulation_df(rdatas, ps.model, pi.petab_problem)
136+
simulation_df = rdatas_to_simulation_df(
137+
rdatas, ps.model, pi.petab_problem
138+
)
141139

142140
solution = petabtests.load_solution(case, model_type, version=version)
143141
gt_chi2 = solution[petabtests.CHI2]
@@ -198,13 +196,13 @@ def _test_case(case, model_type, version, jax):
198196
else:
199197
if (case, model_type, version) in (
200198
("0016", "sbml", "v2.0.0"),
201-
("0024", "sbml", "v2.0.0"),
202-
("0025", "sbml", "v2.0.0"),
203199
("0013", "pysb", "v2.0.0"),
204200
):
205201
# FIXME: issue with events and sensitivities
206202
...
207-
else:
203+
elif ps.model.nx_solver > 0:
204+
# sensitivity calculation is currently only supported for models
205+
# with state variables
208206
check_derivatives(ps, problem_parameters)
209207

210208
if not all([llhs_match, simulations_match]) or not chi2s_match:
@@ -247,12 +245,12 @@ def run():
247245
n_total = 0
248246
version = "v2.0.0"
249247

250-
for jax in (False, True):
248+
for jax_ in (False, True):
251249
cases = list(petabtests.get_cases("sbml", version=version))
252250
n_total += len(cases)
253251
for case in cases:
254252
try:
255-
test_case(case, "sbml", version=version, jax=jax)
253+
test_case(case, "sbml", version=version, jax=jax_)
256254
n_success += 1
257255
except Skipped:
258256
n_skipped += 1

0 commit comments

Comments
 (0)