Skip to content

Commit 947aba7

Browse files
authored
Fixup sbml testsuite gradient check (#2858)
* Update testSBMLSuite.py * Update testSBMLSuite.py
1 parent ac676ef commit 947aba7

1 file changed

Lines changed: 9 additions & 9 deletions

File tree

tests/sbml/testSBMLSuite.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
import shutil
1414
from pathlib import Path
1515

16+
import optimistix
17+
1618
import amici
1719
import pandas as pd
1820
import pytest
@@ -21,7 +23,10 @@
2123
import jax.numpy as jnp
2224
import numpy as np
2325
import diffrax
24-
from amici.jax.petab import DEFAULT_CONTROLLER_SETTINGS
26+
from amici.jax.petab import (
27+
DEFAULT_CONTROLLER_SETTINGS,
28+
DEFAULT_ROOT_FINDER_SETTINGS,
29+
)
2530

2631
from utils import (
2732
verify_results,
@@ -72,11 +77,6 @@ def test_sbml_testsuite_case(test_id, result_path, sbml_semantic_cases_dir):
7277

7378
atol, rtol = apply_settings(settings, solver, model, test_id)
7479

75-
if test_id in sensitivity_check_cases:
76-
model.requireSensitivitiesForAllParameters()
77-
solver.setSensitivityOrder(amici.SensitivityOrder.first)
78-
solver.setSensitivityMethod(amici.SensitivityMethod.forward)
79-
8080
# simulate model
8181
rdata = amici.runAmiciSimulation(model, solver)
8282
if rdata["status"] != amici.AMICI_SUCCESS:
@@ -208,6 +208,7 @@ def jax_sensitivity_check(
208208
icoeff=DEFAULT_CONTROLLER_SETTINGS["icoeff"],
209209
dcoeff=DEFAULT_CONTROLLER_SETTINGS["dcoeff"],
210210
)
211+
root_finder = optimistix.Newton(**DEFAULT_ROOT_FINDER_SETTINGS)
211212

212213
def simulate(pars):
213214
x, _ = jax_model.simulate_condition(
@@ -221,6 +222,7 @@ def simulate(pars):
221222
jnp.zeros((ts_jnp.shape[0], 0)),
222223
solver,
223224
controller,
225+
root_finder,
224226
diffrax.DirectAdjoint(),
225227
diffrax.SteadyStateEvent(),
226228
2**10,
@@ -243,8 +245,6 @@ def simulate(pars):
243245
rdata = amici.runAmiciSimulation(amici_model, solver_amici)
244246

245247
np.testing.assert_allclose(x, rdata["x"], rtol=rtol, atol=atol)
246-
np.testing.assert_allclose(
247-
sx, rdata["sx"], rtol=rtol * tol_factor, atol=atol * tol_factor
248-
)
248+
np.testing.assert_allclose(sx, rdata["sx"], rtol=rtol, atol=atol)
249249
finally:
250250
shutil.rmtree(model_dir, ignore_errors=True)

0 commit comments

Comments
 (0)