1313import shutil
1414from pathlib import Path
1515
16+ import optimistix
17+
1618import amici
1719import pandas as pd
1820import pytest
2123import jax .numpy as jnp
2224import numpy as np
2325import diffrax
24- from amici .jax .petab import DEFAULT_CONTROLLER_SETTINGS
26+ from amici .jax .petab import (
27+ DEFAULT_CONTROLLER_SETTINGS ,
28+ DEFAULT_ROOT_FINDER_SETTINGS ,
29+ )
2530
2631from utils import (
2732 verify_results ,
@@ -72,11 +77,6 @@ def test_sbml_testsuite_case(test_id, result_path, sbml_semantic_cases_dir):
7277
7378 atol , rtol = apply_settings (settings , solver , model , test_id )
7479
75- if test_id in sensitivity_check_cases :
76- model .requireSensitivitiesForAllParameters ()
77- solver .setSensitivityOrder (amici .SensitivityOrder .first )
78- solver .setSensitivityMethod (amici .SensitivityMethod .forward )
79-
8080 # simulate model
8181 rdata = amici .runAmiciSimulation (model , solver )
8282 if rdata ["status" ] != amici .AMICI_SUCCESS :
@@ -208,6 +208,7 @@ def jax_sensitivity_check(
208208 icoeff = DEFAULT_CONTROLLER_SETTINGS ["icoeff" ],
209209 dcoeff = DEFAULT_CONTROLLER_SETTINGS ["dcoeff" ],
210210 )
211+ root_finder = optimistix .Newton (** DEFAULT_ROOT_FINDER_SETTINGS )
211212
212213 def simulate (pars ):
213214 x , _ = jax_model .simulate_condition (
@@ -221,6 +222,7 @@ def simulate(pars):
221222 jnp .zeros ((ts_jnp .shape [0 ], 0 )),
222223 solver ,
223224 controller ,
225+ root_finder ,
224226 diffrax .DirectAdjoint (),
225227 diffrax .SteadyStateEvent (),
226228 2 ** 10 ,
@@ -243,8 +245,6 @@ def simulate(pars):
243245 rdata = amici .runAmiciSimulation (amici_model , solver_amici )
244246
245247 np .testing .assert_allclose (x , rdata ["x" ], rtol = rtol , atol = atol )
246- np .testing .assert_allclose (
247- sx , rdata ["sx" ], rtol = rtol * tol_factor , atol = atol * tol_factor
248- )
248+ np .testing .assert_allclose (sx , rdata ["sx" ], rtol = rtol , atol = atol )
249249 finally :
250250 shutil .rmtree (model_dir , ignore_errors = True )
0 commit comments