|
32 | 32 | "outputs": [], |
33 | 33 | "source": [ |
34 | 34 | "import petab.v1 as petab\n", |
35 | | - "from amici.importers.petab.v1 import import_petab_problem\n", |
| 35 | + "from amici.importers.petab import *\n", |
| 36 | + "from petab.v2 import Problem\n", |
36 | 37 | "\n", |
37 | 38 | "# Define the model name and YAML file location\n", |
38 | 39 | "model_name = \"Boehm_JProteomeRes2014\"\n", |
|
41 | 42 | " f\"master/Benchmark-Models/{model_name}/{model_name}.yaml\"\n", |
42 | 43 | ")\n", |
43 | 44 | "\n", |
44 | | - "# Load the PEtab problem from the YAML file\n", |
45 | | - "petab_problem = petab.Problem.from_yaml(yaml_url)\n", |
| 45 | + "# Load the PEtab problem from the YAML file as a PEtab v2 problem\n", |
| 46 | + "# (the JAX backend only supports PEtab v2)\n", |
| 47 | + "petab_problem = Problem.from_yaml(yaml_url)\n", |
46 | 48 | "\n", |
47 | 49 | "# Import the PEtab problem as a JAX-compatible AMICI problem\n", |
48 | | - "jax_problem = import_petab_problem(\n", |
49 | | - " petab_problem,\n", |
50 | | - " verbose=False, # no text output\n", |
51 | | - " jax=True, # return jax problem\n", |
| 50 | + "pi = PetabImporter(\n", |
| 51 | + " petab_problem=petab_problem,\n", |
| 52 | + " module_name=model_name,\n", |
| 53 | + " compile_=True,\n", |
| 54 | + " jax=True,\n", |
| 55 | + ")\n", |
| 56 | + "\n", |
| 57 | + "jax_problem = pi.create_simulator(\n", |
| 58 | + " force_import=True,\n", |
52 | 59 | ")" |
53 | 60 | ] |
54 | 61 | }, |
|
75 | 82 | "llh, results = run_simulations(jax_problem)" |
76 | 83 | ] |
77 | 84 | }, |
| 85 | + { |
| 86 | + "cell_type": "code", |
| 87 | + "execution_count": null, |
| 88 | + "id": "6c5b2980-13f0-42e9-b13e-0fce05793910", |
| 89 | + "metadata": {}, |
| 90 | + "outputs": [], |
| 91 | + "source": [ |
| 92 | + "results" |
| 93 | + ] |
| 94 | + }, |
78 | 95 | { |
79 | 96 | "cell_type": "markdown", |
80 | 97 | "id": "415962751301c64a", |
|
90 | 107 | "metadata": {}, |
91 | 108 | "outputs": [], |
92 | 109 | "source": [ |
93 | | - "# Define the simulation condition\n", |
94 | | - "simulation_condition = (\"model1_data1\",)\n", |
| 110 | + "# # Define the simulation condition\n", |
| 111 | + "experiment_condition = \"_petab_experiment_condition___default__\"\n", |
95 | 112 | "\n", |
96 | | - "# Access the results for the specified condition\n", |
97 | | - "ic = results[\"simulation_conditions\"].index(simulation_condition)\n", |
| 113 | + "# # Access the results for the specified condition\n", |
| 114 | + "ic = results[\"dynamic_conditions\"].index(experiment_condition)\n", |
98 | 115 | "print(\"llh: \", results[\"llh\"][ic])\n", |
99 | 116 | "print(\"state variables: \", results[\"x\"][ic, :])" |
100 | 117 | ] |
|
146 | 163 | "import matplotlib.pyplot as plt\n", |
147 | 164 | "import numpy as np\n", |
148 | 165 | "\n", |
149 | | - "# Define the simulation condition\n", |
150 | | - "simulation_condition = (\"model1_data1\",)\n", |
| 166 | + "# Define the experiment condition\n", |
| 167 | + "experiment_condition = \"_petab_experiment_condition___default__\"\n", |
151 | 168 | "\n", |
152 | 169 | "\n", |
153 | 170 | "def plot_simulation(results):\n", |
|
158 | 175 | " results (dict): Simulation results from run_simulations.\n", |
159 | 176 | " \"\"\"\n", |
160 | 177 | " # Extract the simulation results for the specific condition\n", |
161 | | - " ic = results[\"simulation_conditions\"].index(simulation_condition)\n", |
| 178 | + " ic = results[\"dynamic_conditions\"].index(experiment_condition)\n", |
162 | 179 | "\n", |
163 | 180 | " # Create a new figure for the state trajectories\n", |
164 | 181 | " plt.figure(figsize=(8, 6))\n", |
|
172 | 189 | " # Add labels, legend, and grid\n", |
173 | 190 | " plt.xlabel(\"Time\")\n", |
174 | 191 | " plt.ylabel(\"State Values\")\n", |
175 | | - " plt.title(simulation_condition)\n", |
| 192 | + " plt.title(experiment_condition)\n", |
176 | 193 | " plt.legend()\n", |
177 | 194 | " plt.grid(True)\n", |
178 | 195 | " plt.show()\n", |
|
187 | 204 | "id": "4fa97c33719c2277", |
188 | 205 | "metadata": {}, |
189 | 206 | "source": [ |
190 | | - "`run_simulations` enables users to specify the simulation conditions to be executed. For more complex models, this allows for restricting simulations to a subset of conditions. Since the Böhm model includes only a single condition, we demonstrate this functionality by simulating no condition at all." |
191 | | - ] |
192 | | - }, |
193 | | - { |
194 | | - "cell_type": "code", |
195 | | - "execution_count": null, |
196 | | - "id": "7950774a3e989042", |
197 | | - "metadata": {}, |
198 | | - "outputs": [], |
199 | | - "source": [ |
200 | | - "llh, results = run_simulations(jax_problem, simulation_conditions=tuple())\n", |
201 | | - "results" |
| 207 | + "`run_simulations` enables users to specify the simulation experiments to be executed. For more complex models, this allows for restricting simulations to a subset of experiments by passing a tuple of experiment ids under the keyword `simulation_experiments` to `run_simulations`." |
202 | 208 | ] |
203 | 209 | }, |
204 | 210 | { |
|
384 | 390 | "from amici.jax import ReturnValue\n", |
385 | 391 | "\n", |
386 | 392 | "# Define the simulation condition\n", |
387 | | - "simulation_condition = (\"model1_data1\",)\n", |
388 | | - "ic = jax_problem.simulation_conditions.index(simulation_condition)\n", |
| 393 | + "experiment_condition = \"_petab_experiment_condition___default__\"\n", |
| 394 | + "ic = 0\n", |
389 | 395 | "\n", |
390 | 396 | "# Load condition-specific data\n", |
391 | 397 | "ts_dyn = jax_problem._ts_dyn[ic, :]\n", |
|
397 | 403 | "nps = jax_problem._np_numeric[ic, :]\n", |
398 | 404 | "\n", |
399 | 405 | "# Load parameters for the specified condition\n", |
400 | | - "p = jax_problem.load_model_parameters(simulation_condition[0])\n", |
| 406 | + "p = jax_problem.load_model_parameters(jax_problem._petab_problem.experiments[0], is_preeq=False)\n", |
401 | 407 | "\n", |
402 | 408 | "\n", |
403 | 409 | "# Define a function to compute the gradient with respect to dynamic timepoints\n", |
|
431 | 437 | "cell_type": "markdown", |
432 | 438 | "id": "19ca88c8900584ce", |
433 | 439 | "metadata": {}, |
434 | | - "source": "## Model training" |
| 440 | + "source": [ |
| 441 | + "## Model training" |
| 442 | + ] |
435 | 443 | }, |
436 | 444 | { |
437 | 445 | "cell_type": "markdown", |
438 | 446 | "id": "7f99c046d7d4e225", |
439 | 447 | "metadata": {}, |
440 | | - "source": "This setup makes it pretty straightforward to train models using [equinox](https://docs.kidger.site/equinox/) and [optax](https://optax.readthedocs.io/en/latest/) frameworks. Below we provide barebones implementation that runs training for 5 steps using Adam." |
| 448 | + "source": [ |
| 449 | + "This setup makes it pretty straightforward to train models using [equinox](https://docs.kidger.site/equinox/) and [optax](https://optax.readthedocs.io/en/latest/) frameworks. Below we provide barebones implementation that runs training for 5 steps using Adam." |
| 450 | + ] |
441 | 451 | }, |
442 | 452 | { |
443 | 453 | "cell_type": "code", |
|
569 | 579 | "from amici.sim.sundials.petab.v1 import simulate_petab\n", |
570 | 580 | "\n", |
571 | 581 | "# Import the PEtab problem as a standard AMICI model\n", |
572 | | - "amici_model = import_petab_problem(\n", |
573 | | - " petab_problem,\n", |
574 | | - " verbose=False,\n", |
575 | | - " jax=False, # load the amici model this time\n", |
| 582 | + "pi = PetabImporter(\n", |
| 583 | + " petab_problem=petab_problem,\n", |
| 584 | + " module_name=model_name,\n", |
| 585 | + " compile_=True,\n", |
| 586 | + " jax=False,\n", |
| 587 | + ")\n", |
| 588 | + "\n", |
| 589 | + "amici_model = pi.create_simulator(\n", |
| 590 | + " force_import=True,\n", |
576 | 591 | ")\n", |
577 | 592 | "\n", |
578 | 593 | "# Configure the solver with appropriate tolerances\n", |
579 | | - "solver = amici_model.create_solver()\n", |
580 | | - "solver.set_absolute_tolerance(1e-8)\n", |
581 | | - "solver.set_relative_tolerance(1e-16)\n", |
| 594 | + "amici_model.solver.set_absolute_tolerance(1e-8)\n", |
| 595 | + "amici_model.solver.set_relative_tolerance(1e-16)\n", |
582 | 596 | "\n", |
583 | 597 | "# Prepare the parameters for the simulation\n", |
584 | 598 | "problem_parameters = dict(\n", |
|
594 | 608 | "outputs": [], |
595 | 609 | "source": [ |
596 | 610 | "# Profile simulation only\n", |
597 | | - "solver.set_sensitivity_order(SensitivityOrder.none)" |
| 611 | + "amici_model.solver.set_sensitivity_order(SensitivityOrder.none)" |
598 | 612 | ] |
599 | 613 | }, |
600 | 614 | { |
601 | | - "metadata": {}, |
602 | 615 | "cell_type": "code", |
603 | | - "outputs": [], |
604 | 616 | "execution_count": null, |
| 617 | + "id": "42cbc67bc09b67dc", |
| 618 | + "metadata": {}, |
| 619 | + "outputs": [], |
605 | 620 | "source": [ |
606 | 621 | "%%timeit\n", |
607 | | - "simulate_petab(\n", |
608 | | - " petab_problem,\n", |
609 | | - " amici_model=amici_model,\n", |
610 | | - " solver=solver,\n", |
611 | | - " problem_parameters=problem_parameters,\n", |
612 | | - " scaled_parameters=True,\n", |
613 | | - " scaled_gradients=True,\n", |
614 | | - ")" |
615 | | - ], |
616 | | - "id": "42cbc67bc09b67dc" |
| 622 | + "amici_model.simulate(petab_problem.get_x_nominal_dict())" |
| 623 | + ] |
617 | 624 | }, |
618 | 625 | { |
619 | | - "metadata": {}, |
620 | 626 | "cell_type": "code", |
621 | | - "outputs": [], |
622 | 627 | "execution_count": null, |
| 628 | + "id": "4f1c06c5893a9c07", |
| 629 | + "metadata": {}, |
| 630 | + "outputs": [], |
623 | 631 | "source": [ |
624 | 632 | "# Profile gradient computation using forward sensitivity analysis\n", |
625 | | - "solver.set_sensitivity_order(SensitivityOrder.first)\n", |
626 | | - "solver.set_sensitivity_method(SensitivityMethod.forward)" |
627 | | - ], |
628 | | - "id": "4f1c06c5893a9c07" |
| 633 | + "amici_model.solver.set_sensitivity_order(SensitivityOrder.first)\n", |
| 634 | + "amici_model.solver.set_sensitivity_method(SensitivityMethod.forward)" |
| 635 | + ] |
629 | 636 | }, |
630 | 637 | { |
631 | | - "metadata": {}, |
632 | 638 | "cell_type": "code", |
633 | | - "outputs": [], |
634 | 639 | "execution_count": null, |
| 640 | + "id": "7367a19bcea98597", |
| 641 | + "metadata": {}, |
| 642 | + "outputs": [], |
635 | 643 | "source": [ |
636 | 644 | "%%timeit\n", |
637 | | - "simulate_petab(\n", |
638 | | - " petab_problem,\n", |
639 | | - " amici_model=amici_model,\n", |
640 | | - " solver=solver,\n", |
641 | | - " problem_parameters=problem_parameters,\n", |
642 | | - " scaled_parameters=True,\n", |
643 | | - " scaled_gradients=True,\n", |
644 | | - ")" |
645 | | - ], |
646 | | - "id": "7367a19bcea98597" |
| 645 | + "amici_model.simulate(petab_problem.get_x_nominal_dict())" |
| 646 | + ] |
647 | 647 | }, |
648 | 648 | { |
649 | | - "metadata": {}, |
650 | 649 | "cell_type": "code", |
651 | | - "outputs": [], |
652 | 650 | "execution_count": null, |
| 651 | + "id": "a31e8eda806c2d7", |
| 652 | + "metadata": {}, |
| 653 | + "outputs": [], |
653 | 654 | "source": [ |
654 | 655 | "# Profile gradient computation using adjoint sensitivity analysis\n", |
655 | | - "solver.set_sensitivity_order(SensitivityOrder.first)\n", |
656 | | - "solver.set_sensitivity_method(SensitivityMethod.adjoint)" |
657 | | - ], |
658 | | - "id": "a31e8eda806c2d7" |
| 656 | + "amici_model.solver.set_sensitivity_order(SensitivityOrder.first)\n", |
| 657 | + "amici_model.solver.set_sensitivity_method(SensitivityMethod.adjoint)" |
| 658 | + ] |
659 | 659 | }, |
660 | 660 | { |
661 | | - "metadata": {}, |
662 | 661 | "cell_type": "code", |
663 | | - "outputs": [], |
664 | 662 | "execution_count": null, |
| 663 | + "id": "3f2ab1acb3ba818f", |
| 664 | + "metadata": {}, |
| 665 | + "outputs": [], |
665 | 666 | "source": [ |
666 | 667 | "%%timeit\n", |
667 | | - "simulate_petab(\n", |
668 | | - " petab_problem,\n", |
669 | | - " amici_model=amici_model,\n", |
670 | | - " solver=solver,\n", |
671 | | - " problem_parameters=problem_parameters,\n", |
672 | | - " scaled_parameters=True,\n", |
673 | | - " scaled_gradients=True,\n", |
674 | | - ")" |
675 | | - ], |
676 | | - "id": "3f2ab1acb3ba818f" |
| 668 | + "amici_model.simulate(petab_problem.get_x_nominal_dict())" |
| 669 | + ] |
677 | 670 | } |
678 | 671 | ], |
679 | 672 | "metadata": { |
|
691 | 684 | "mimetype": "text/x-python", |
692 | 685 | "name": "python", |
693 | 686 | "nbconvert_exporter": "python", |
694 | | - "pygments_lexer": "ipython3" |
| 687 | + "pygments_lexer": "ipython3", |
| 688 | + "version": "3.12.3" |
695 | 689 | } |
696 | 690 | }, |
697 | 691 | "nbformat": 4, |
|
0 commit comments