Skip to content

Commit 27184ac

Browse files
authored
Fix JAX test failures related to diffrax/optimistix release (#3129)
* adjust test_jax tols * increase max steps for jax benchmarks * pin optax for notebook test * increase max steps for petab tests * increase petab max steps again * even higher max steps for Weber benchmark * pin optax in docs build
1 parent 1c5d990 commit 27184ac

5 files changed

Lines changed: 12 additions & 9 deletions

File tree

doc/rtd_requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ setuptools>=67.7.2
99
git+https://github.com/jmuhlich/pysb@22d69a350b472f33d85ba64ffb10b190483c1c98
1010
# For forward type definition in generate_equinox
1111
matplotlib>=3.7.1
12-
optax
12+
optax==0.2.6
1313
nbsphinx
1414
nbformat
1515
myst-parser

python/sdist/amici/jax/petab.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1550,7 +1550,7 @@ def run_simulations(
15501550
steady_state_event: Callable[
15511551
..., diffrax._custom_types.BoolScalarLike
15521552
] = diffrax.steady_state_event(),
1553-
max_steps: int = 2**10,
1553+
max_steps: int = 2**13,
15541554
ret: ReturnValue | str = ReturnValue.llh,
15551555
):
15561556
"""
@@ -1653,7 +1653,7 @@ def petab_simulate(
16531653
steady_state_event: Callable[
16541654
..., diffrax._custom_types.BoolScalarLike
16551655
] = diffrax.steady_state_event(),
1656-
max_steps: int = 2**10,
1656+
max_steps: int = 2**13,
16571657
):
16581658
"""
16591659
Run simulations for a problem and return the results as a petab simulation dataframe.

python/tests/test_jax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ def check_fields_jax(
204204
"iy_trafos": jnp.array(iy_trafos),
205205
"x_preeq": jnp.array([]),
206206
"solver": diffrax.Kvaerno5(),
207-
"controller": diffrax.PIDController(atol=ATOL_SIM, rtol=RTOL_SIM),
207+
"controller": diffrax.PIDController(atol=1e-8, rtol=1e-8),
208208
"root_finder": optimistix.Newton(atol=ATOL_SIM, rtol=RTOL_SIM),
209209
"adjoint": diffrax.RecursiveCheckpointAdjoint(),
210210
"steady_state_event": diffrax.steady_state_event(),

scripts/installAmiciSource.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ python -m pip install --upgrade pip wheel
3939
python -m pip install --upgrade pip setuptools cmake_build_extension==0.6.0 numpy petab swig
4040
python -m pip install git+https://github.com/pysb/pysb@master # for SPM with compartments
4141
python -m pip install git+https://github.com/patrick-kidger/diffrax@main # for events with direction
42-
python -m pip install optax # for jax petab notebook
42+
python -m pip install 'optax<0.2.7' # for jax petab notebook
4343
AMICI_BUILD_TEMP="${AMICI_PATH}/python/sdist/build/temp" \
4444
python -m pip install --verbose -e "${AMICI_PATH}/python/sdist[petab,test,vis,jax]" --no-build-isolation
4545
deactivate

tests/benchmark_models/test_petab_benchmark_jax.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from amici.importers.petab.v1 import (
1111
import_petab_problem,
1212
)
13-
from amici.jax.petab import run_simulations
13+
from amici.jax.petab import run_simulations, DEFAULT_CONTROLLER_SETTINGS
1414
from amici.sim.sundials import SensitivityMethod, SensitivityOrder
1515
from amici.sim.sundials.petab.v1 import (
1616
LLH,
@@ -113,11 +113,11 @@ def test_jax_llh(benchmark_problem):
113113
if problem_id == "Weber_BMC2015":
114114
atol = cur_settings.atol_sim
115115
rtol = cur_settings.rtol_sim
116-
max_steps = 2 * 10**5
116+
max_steps = 4 * 10**7
117117
else:
118118
atol = 1e-8
119119
rtol = 1e-8
120-
max_steps = 1024
120+
max_steps = 2 * 10**5
121121
beartype(run_simulations)(jax_problem)
122122
(llh_jax, _), sllh_jax = eqx.filter_value_and_grad(
123123
run_simulations, has_aux=True
@@ -130,7 +130,10 @@ def test_jax_llh(benchmark_problem):
130130
),
131131
)
132132
else:
133-
llh_jax, _ = beartype(run_simulations)(jax_problem)
133+
llh_jax, _ = beartype(run_simulations)(
134+
jax_problem,
135+
max_steps=2 * 10**5,
136+
)
134137

135138
np.testing.assert_allclose(
136139
llh_jax,

0 commit comments

Comments
 (0)