Skip to content

Commit b2699a8

Browse files
Add support in JAX backend for PEtab v2 (#3115)
* event assignments jax - sbml cases 348 - 404 * fix up sbml test cases - not implemented priority, update t_eps, fix heaviside init * initialValue False not implemented * try fix other test cases * params only in explicit triggers - and matrix only in JAX again * oops committed breakpoint * looking for initialValue test cases support explicit initial value event assignments fix deltax tcl * do not update h pre-solve * handle_t0_event * reinstate time skip (hack diffrax bug?) * Update python/sdist/amici/jax/_simulation.py Co-authored-by: Fabian Fröhlich <fabian.frohlich@crick.ac.uk> * Revert "Update python/sdist/amici/jax/_simulation.py" This reverts commit 82caa94. * rm clip controller * handle t0 event near zero * skip non-time dependent event assignment cases * first pass petabv2 - updating JAXProblem init * petabv2 test cases up to 15-ish * petab v2 test cases up to 27-ish * rework petabv2 jax test cases with ExperimentsToSbmlEvents and no v1 parameter_mapping * fix some rebase issues * add test skip for petab v1 * remaining petab v2 test cases * update workflows - deactivate petab_sciml wf for now * update tests to skip on petab v2 type error * skip some more tests and reinstate more specific implicit triggers check * fixup benchmark skipping check * tidying - add docstrings - rm outputs in notebook * skip implicit benchmark case - and prior distribution cases * review feedback - rm simultaneous event check - implement sequential event assignments with assumed priority * fix test_jax tests and sbml cases with no y0 * fix pysb test case for jax * use h symbol and update example notebook * rm v1 instance checks and improve preeq conditionals * implement implicit triggers using fixed parameters check * temp workaround for pysb build issue * pysb workaround in petab workflow too * fix indentation errors in gen JAX code - restore petabv1 conditional * skip JAXProblems with v1 problems * skip JAXProblems with v1 problems - again * check implicit triggers in sep function * use petabv2 constants - avoid df usage in loops * restore pysb install to master * tody petab file --------- Co-authored-by: Fabian Fröhlich <fabian.frohlich@crick.ac.uk>
1 parent 67ad11a commit b2699a8

19 files changed

Lines changed: 1014 additions & 513 deletions

File tree

.github/workflows/test_petab_sciml.yml

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
name: PEtab SciML
2-
on:
3-
push:
4-
branches:
5-
- main
6-
- 'release*'
7-
pull_request:
8-
branches:
9-
- main
10-
merge_group:
11-
workflow_dispatch:
2+
# on:
3+
# push:
4+
# branches:
5+
# - main
6+
# - 'release*'
7+
# pull_request:
8+
# branches:
9+
# - main
10+
# merge_group:
11+
# workflow_dispatch:
1212

1313
jobs:
1414
build:

.github/workflows/test_petab_test_suite.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ jobs:
172172
git clone https://github.com/PEtab-dev/petab_test_suite \
173173
&& source ./venv/bin/activate \
174174
&& cd petab_test_suite \
175-
&& git checkout c12b9dc4e4c5585b1b83a1d6e89fd22447c46d03 \
175+
&& git checkout 9542847fb99bcbdffc236e2ef45ba90580a210fa \
176176
&& pip3 install -e .
177177
178178
# TODO: once there is a PEtab v2 benchmark collection
@@ -186,7 +186,7 @@ jobs:
186186
run: |
187187
source ./venv/bin/activate \
188188
&& python3 -m pip uninstall -y petab \
189-
&& python3 -m pip install git+https://github.com/petab-dev/libpetab-python.git@8dc6c1c4b801fba5acc35fcd25308a659d01050e \
189+
&& python3 -m pip install git+https://github.com/petab-dev/libpetab-python.git@d57d9fed8d8d5f8592e76d0b15676e05397c3b4b \
190190
&& python3 -m pip install git+https://github.com/pysb/pysb@master \
191191
&& python3 -m pip install sympy>=1.12.1
192192

doc/examples/example_jax_petab/ExampleJaxPEtab.ipynb

Lines changed: 83 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@
3232
"outputs": [],
3333
"source": [
3434
"import petab.v1 as petab\n",
35-
"from amici.importers.petab.v1 import import_petab_problem\n",
35+
"from amici.importers.petab import *\n",
36+
"from petab.v2 import Problem\n",
3637
"\n",
3738
"# Define the model name and YAML file location\n",
3839
"model_name = \"Boehm_JProteomeRes2014\"\n",
@@ -41,14 +42,20 @@
4142
" f\"master/Benchmark-Models/{model_name}/{model_name}.yaml\"\n",
4243
")\n",
4344
"\n",
44-
"# Load the PEtab problem from the YAML file\n",
45-
"petab_problem = petab.Problem.from_yaml(yaml_url)\n",
45+
"# Load the PEtab problem from the YAML file as a PEtab v2 problem\n",
46+
"# (the JAX backend only supports PEtab v2)\n",
47+
"petab_problem = Problem.from_yaml(yaml_url)\n",
4648
"\n",
4749
"# Import the PEtab problem as a JAX-compatible AMICI problem\n",
48-
"jax_problem = import_petab_problem(\n",
49-
" petab_problem,\n",
50-
" verbose=False, # no text output\n",
51-
" jax=True, # return jax problem\n",
50+
"pi = PetabImporter(\n",
51+
" petab_problem=petab_problem,\n",
52+
" module_name=model_name,\n",
53+
" compile_=True,\n",
54+
" jax=True,\n",
55+
")\n",
56+
"\n",
57+
"jax_problem = pi.create_simulator(\n",
58+
" force_import=True,\n",
5259
")"
5360
]
5461
},
@@ -75,6 +82,16 @@
7582
"llh, results = run_simulations(jax_problem)"
7683
]
7784
},
85+
{
86+
"cell_type": "code",
87+
"execution_count": null,
88+
"id": "6c5b2980-13f0-42e9-b13e-0fce05793910",
89+
"metadata": {},
90+
"outputs": [],
91+
"source": [
92+
"results"
93+
]
94+
},
7895
{
7996
"cell_type": "markdown",
8097
"id": "415962751301c64a",
@@ -90,11 +107,11 @@
90107
"metadata": {},
91108
"outputs": [],
92109
"source": [
93-
"# Define the simulation condition\n",
94-
"simulation_condition = (\"model1_data1\",)\n",
110+
"# # Define the simulation condition\n",
111+
"experiment_condition = \"_petab_experiment_condition___default__\"\n",
95112
"\n",
96-
"# Access the results for the specified condition\n",
97-
"ic = results[\"simulation_conditions\"].index(simulation_condition)\n",
113+
"# # Access the results for the specified condition\n",
114+
"ic = results[\"dynamic_conditions\"].index(experiment_condition)\n",
98115
"print(\"llh: \", results[\"llh\"][ic])\n",
99116
"print(\"state variables: \", results[\"x\"][ic, :])"
100117
]
@@ -146,8 +163,8 @@
146163
"import matplotlib.pyplot as plt\n",
147164
"import numpy as np\n",
148165
"\n",
149-
"# Define the simulation condition\n",
150-
"simulation_condition = (\"model1_data1\",)\n",
166+
"# Define the experiment condition\n",
167+
"experiment_condition = \"_petab_experiment_condition___default__\"\n",
151168
"\n",
152169
"\n",
153170
"def plot_simulation(results):\n",
@@ -158,7 +175,7 @@
158175
" results (dict): Simulation results from run_simulations.\n",
159176
" \"\"\"\n",
160177
" # Extract the simulation results for the specific condition\n",
161-
" ic = results[\"simulation_conditions\"].index(simulation_condition)\n",
178+
" ic = results[\"dynamic_conditions\"].index(experiment_condition)\n",
162179
"\n",
163180
" # Create a new figure for the state trajectories\n",
164181
" plt.figure(figsize=(8, 6))\n",
@@ -172,7 +189,7 @@
172189
" # Add labels, legend, and grid\n",
173190
" plt.xlabel(\"Time\")\n",
174191
" plt.ylabel(\"State Values\")\n",
175-
" plt.title(simulation_condition)\n",
192+
" plt.title(experiment_condition)\n",
176193
" plt.legend()\n",
177194
" plt.grid(True)\n",
178195
" plt.show()\n",
@@ -187,18 +204,7 @@
187204
"id": "4fa97c33719c2277",
188205
"metadata": {},
189206
"source": [
190-
"`run_simulations` enables users to specify the simulation conditions to be executed. For more complex models, this allows for restricting simulations to a subset of conditions. Since the Böhm model includes only a single condition, we demonstrate this functionality by simulating no condition at all."
191-
]
192-
},
193-
{
194-
"cell_type": "code",
195-
"execution_count": null,
196-
"id": "7950774a3e989042",
197-
"metadata": {},
198-
"outputs": [],
199-
"source": [
200-
"llh, results = run_simulations(jax_problem, simulation_conditions=tuple())\n",
201-
"results"
207+
"`run_simulations` enables users to specify the simulation experiments to be executed. For more complex models, this allows for restricting simulations to a subset of experiments by passing a tuple of experiment ids under the keyword `simulation_experiments` to `run_simulations`."
202208
]
203209
},
204210
{
@@ -384,8 +390,8 @@
384390
"from amici.jax import ReturnValue\n",
385391
"\n",
386392
"# Define the simulation condition\n",
387-
"simulation_condition = (\"model1_data1\",)\n",
388-
"ic = jax_problem.simulation_conditions.index(simulation_condition)\n",
393+
"experiment_condition = \"_petab_experiment_condition___default__\"\n",
394+
"ic = 0\n",
389395
"\n",
390396
"# Load condition-specific data\n",
391397
"ts_dyn = jax_problem._ts_dyn[ic, :]\n",
@@ -397,7 +403,7 @@
397403
"nps = jax_problem._np_numeric[ic, :]\n",
398404
"\n",
399405
"# Load parameters for the specified condition\n",
400-
"p = jax_problem.load_model_parameters(simulation_condition[0])\n",
406+
"p = jax_problem.load_model_parameters(jax_problem._petab_problem.experiments[0], is_preeq=False)\n",
401407
"\n",
402408
"\n",
403409
"# Define a function to compute the gradient with respect to dynamic timepoints\n",
@@ -431,13 +437,17 @@
431437
"cell_type": "markdown",
432438
"id": "19ca88c8900584ce",
433439
"metadata": {},
434-
"source": "## Model training"
440+
"source": [
441+
"## Model training"
442+
]
435443
},
436444
{
437445
"cell_type": "markdown",
438446
"id": "7f99c046d7d4e225",
439447
"metadata": {},
440-
"source": "This setup makes it pretty straightforward to train models using [equinox](https://docs.kidger.site/equinox/) and [optax](https://optax.readthedocs.io/en/latest/) frameworks. Below we provide barebones implementation that runs training for 5 steps using Adam."
448+
"source": [
449+
"This setup makes it pretty straightforward to train models using [equinox](https://docs.kidger.site/equinox/) and [optax](https://optax.readthedocs.io/en/latest/) frameworks. Below we provide barebones implementation that runs training for 5 steps using Adam."
450+
]
441451
},
442452
{
443453
"cell_type": "code",
@@ -569,16 +579,20 @@
569579
"from amici.sim.sundials.petab.v1 import simulate_petab\n",
570580
"\n",
571581
"# Import the PEtab problem as a standard AMICI model\n",
572-
"amici_model = import_petab_problem(\n",
573-
" petab_problem,\n",
574-
" verbose=False,\n",
575-
" jax=False, # load the amici model this time\n",
582+
"pi = PetabImporter(\n",
583+
" petab_problem=petab_problem,\n",
584+
" module_name=model_name,\n",
585+
" compile_=True,\n",
586+
" jax=False,\n",
587+
")\n",
588+
"\n",
589+
"amici_model = pi.create_simulator(\n",
590+
" force_import=True,\n",
576591
")\n",
577592
"\n",
578593
"# Configure the solver with appropriate tolerances\n",
579-
"solver = amici_model.create_solver()\n",
580-
"solver.set_absolute_tolerance(1e-8)\n",
581-
"solver.set_relative_tolerance(1e-16)\n",
594+
"amici_model.solver.set_absolute_tolerance(1e-8)\n",
595+
"amici_model.solver.set_relative_tolerance(1e-16)\n",
582596
"\n",
583597
"# Prepare the parameters for the simulation\n",
584598
"problem_parameters = dict(\n",
@@ -594,86 +608,65 @@
594608
"outputs": [],
595609
"source": [
596610
"# Profile simulation only\n",
597-
"solver.set_sensitivity_order(SensitivityOrder.none)"
611+
"amici_model.solver.set_sensitivity_order(SensitivityOrder.none)"
598612
]
599613
},
600614
{
601-
"metadata": {},
602615
"cell_type": "code",
603-
"outputs": [],
604616
"execution_count": null,
617+
"id": "42cbc67bc09b67dc",
618+
"metadata": {},
619+
"outputs": [],
605620
"source": [
606621
"%%timeit\n",
607-
"simulate_petab(\n",
608-
" petab_problem,\n",
609-
" amici_model=amici_model,\n",
610-
" solver=solver,\n",
611-
" problem_parameters=problem_parameters,\n",
612-
" scaled_parameters=True,\n",
613-
" scaled_gradients=True,\n",
614-
")"
615-
],
616-
"id": "42cbc67bc09b67dc"
622+
"amici_model.simulate(petab_problem.get_x_nominal_dict())"
623+
]
617624
},
618625
{
619-
"metadata": {},
620626
"cell_type": "code",
621-
"outputs": [],
622627
"execution_count": null,
628+
"id": "4f1c06c5893a9c07",
629+
"metadata": {},
630+
"outputs": [],
623631
"source": [
624632
"# Profile gradient computation using forward sensitivity analysis\n",
625-
"solver.set_sensitivity_order(SensitivityOrder.first)\n",
626-
"solver.set_sensitivity_method(SensitivityMethod.forward)"
627-
],
628-
"id": "4f1c06c5893a9c07"
633+
"amici_model.solver.set_sensitivity_order(SensitivityOrder.first)\n",
634+
"amici_model.solver.set_sensitivity_method(SensitivityMethod.forward)"
635+
]
629636
},
630637
{
631-
"metadata": {},
632638
"cell_type": "code",
633-
"outputs": [],
634639
"execution_count": null,
640+
"id": "7367a19bcea98597",
641+
"metadata": {},
642+
"outputs": [],
635643
"source": [
636644
"%%timeit\n",
637-
"simulate_petab(\n",
638-
" petab_problem,\n",
639-
" amici_model=amici_model,\n",
640-
" solver=solver,\n",
641-
" problem_parameters=problem_parameters,\n",
642-
" scaled_parameters=True,\n",
643-
" scaled_gradients=True,\n",
644-
")"
645-
],
646-
"id": "7367a19bcea98597"
645+
"amici_model.simulate(petab_problem.get_x_nominal_dict())"
646+
]
647647
},
648648
{
649-
"metadata": {},
650649
"cell_type": "code",
651-
"outputs": [],
652650
"execution_count": null,
651+
"id": "a31e8eda806c2d7",
652+
"metadata": {},
653+
"outputs": [],
653654
"source": [
654655
"# Profile gradient computation using adjoint sensitivity analysis\n",
655-
"solver.set_sensitivity_order(SensitivityOrder.first)\n",
656-
"solver.set_sensitivity_method(SensitivityMethod.adjoint)"
657-
],
658-
"id": "a31e8eda806c2d7"
656+
"amici_model.solver.set_sensitivity_order(SensitivityOrder.first)\n",
657+
"amici_model.solver.set_sensitivity_method(SensitivityMethod.adjoint)"
658+
]
659659
},
660660
{
661-
"metadata": {},
662661
"cell_type": "code",
663-
"outputs": [],
664662
"execution_count": null,
663+
"id": "3f2ab1acb3ba818f",
664+
"metadata": {},
665+
"outputs": [],
665666
"source": [
666667
"%%timeit\n",
667-
"simulate_petab(\n",
668-
" petab_problem,\n",
669-
" amici_model=amici_model,\n",
670-
" solver=solver,\n",
671-
" problem_parameters=problem_parameters,\n",
672-
" scaled_parameters=True,\n",
673-
" scaled_gradients=True,\n",
674-
")"
675-
],
676-
"id": "3f2ab1acb3ba818f"
668+
"amici_model.simulate(petab_problem.get_x_nominal_dict())"
669+
]
677670
}
678671
],
679672
"metadata": {
@@ -691,7 +684,8 @@
691684
"mimetype": "text/x-python",
692685
"name": "python",
693686
"nbconvert_exporter": "python",
694-
"pygments_lexer": "ipython3"
687+
"pygments_lexer": "ipython3",
688+
"version": "3.12.3"
695689
}
696690
},
697691
"nbformat": 4,

python/sdist/amici/_symbolic/de_model.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2679,11 +2679,16 @@ def has_priority_events(self) -> bool:
26792679
def has_implicit_event_assignments(self) -> bool:
26802680
"""
26812681
Checks whether the model has event assignments with implicit triggers
2682+
(i.e. triggers that are not time based).
26822683
26832684
:return:
26842685
boolean indicating if event assignments with implicit triggers are present
26852686
"""
2686-
return any(event.updates_state and not event.has_explicit_trigger_times({}) for event in self._events)
2687+
fixed_symbols = set([k._symbol for k in self._fixed_parameters])
2688+
allowed_symbols = fixed_symbols | {amici_time_symbol}
2689+
# TODO: update to use has_explicit_trigger_times once
2690+
# https://github.com/AMICI-dev/AMICI/issues/3126 is resolved
2691+
return any(event.updates_state and event._has_implicit_triggers(allowed_symbols) for event in self._events)
26872692

26882693
def toposort_expressions(
26892694
self, reorder: bool = True

0 commit comments

Comments
 (0)