55import sys
66
77import diffrax
8+ import jax
89import pandas as pd
910import petabtests
1011import pytest
2122)
2223from amici .sim .sundials .petab import PetabSimulator
2324from petab import v2
24- import jax
2525
2626logger = get_logger (__name__ , logging .DEBUG )
2727set_log_level (get_logger ("amici.petab_import" ), logging .DEBUG )
@@ -70,7 +70,6 @@ def _test_case(case, model_type, version, jax):
7070 f"petab_{ model_type } _test_case_{ case } _{ version .replace ('.' , '_' )} "
7171 )
7272
73-
7473 if jax :
7574 from amici .jax import petab_simulate , run_simulations
7675 from amici .jax .petab import DEFAULT_CONTROLLER_SETTINGS
@@ -90,28 +89,25 @@ def _test_case(case, model_type, version, jax):
9089
9190 if case .startswith ("0016" ):
9291 controller = diffrax .PIDController (
93- ** DEFAULT_CONTROLLER_SETTINGS ,
94- dtmax = 0.5
92+ ** DEFAULT_CONTROLLER_SETTINGS , dtmax = 0.5
9593 )
9694 else :
97- controller = diffrax .PIDController (
98- ** DEFAULT_CONTROLLER_SETTINGS
99- )
95+ controller = diffrax .PIDController (** DEFAULT_CONTROLLER_SETTINGS )
10096
10197 llh , _ = run_simulations (
102- jax_problem ,
103- steady_state_event = steady_state_event ,
98+ jax_problem ,
99+ steady_state_event = steady_state_event ,
104100 controller = controller ,
105101 )
106102 chi2 , _ = run_simulations (
107- jax_problem ,
108- ret = "chi2" ,
109- steady_state_event = steady_state_event ,
103+ jax_problem ,
104+ ret = "chi2" ,
105+ steady_state_event = steady_state_event ,
110106 controller = controller ,
111107 )
112108 simulation_df = petab_simulate (
113- jax_problem ,
114- steady_state_event = steady_state_event ,
109+ jax_problem ,
110+ steady_state_event = steady_state_event ,
115111 controller = controller ,
116112 )
117113 else :
@@ -137,7 +133,9 @@ def _test_case(case, model_type, version, jax):
137133 )
138134 chi2 = sum (rdata .chi2 for rdata in rdatas )
139135 llh = res .llh
140- simulation_df = rdatas_to_simulation_df (rdatas , ps .model , pi .petab_problem )
136+ simulation_df = rdatas_to_simulation_df (
137+ rdatas , ps .model , pi .petab_problem
138+ )
141139
142140 solution = petabtests .load_solution (case , model_type , version = version )
143141 gt_chi2 = solution [petabtests .CHI2 ]
@@ -198,13 +196,13 @@ def _test_case(case, model_type, version, jax):
198196 else :
199197 if (case , model_type , version ) in (
200198 ("0016" , "sbml" , "v2.0.0" ),
201- ("0024" , "sbml" , "v2.0.0" ),
202- ("0025" , "sbml" , "v2.0.0" ),
203199 ("0013" , "pysb" , "v2.0.0" ),
204200 ):
205201 # FIXME: issue with events and sensitivities
206202 ...
207- else :
203+ elif ps .model .nx_solver > 0 :
204+ # sensitivity calculation is currently only supported for models
205+ # with state variables
208206 check_derivatives (ps , problem_parameters )
209207
210208 if not all ([llhs_match , simulations_match ]) or not chi2s_match :
@@ -247,12 +245,12 @@ def run():
247245 n_total = 0
248246 version = "v2.0.0"
249247
250- for jax in (False , True ):
248+ for jax_ in (False , True ):
251249 cases = list (petabtests .get_cases ("sbml" , version = version ))
252250 n_total += len (cases )
253251 for case in cases :
254252 try :
255- test_case (case , "sbml" , version = version , jax = jax )
253+ test_case (case , "sbml" , version = version , jax = jax_ )
256254 n_success += 1
257255 except Skipped :
258256 n_skipped += 1
0 commit comments