|
46 | 46 | "output_type": "stream", |
47 | 47 | "text": [ |
48 | 48 | "Cloning into 'tmp/benchmark-models'...\n", |
49 | | - "remote: Enumerating objects: 336, done.\u001B[K\n", |
50 | | - "remote: Counting objects: 100% (336/336), done.\u001B[K\n", |
51 | | - "remote: Compressing objects: 100% (285/285), done.\u001B[K\n", |
52 | | - "remote: Total 336 (delta 88), reused 216 (delta 39), pack-reused 0\u001B[K\n", |
| 49 | + "remote: Enumerating objects: 336, done.\u001b[K\n", |
| 50 | + "remote: Counting objects: 100% (336/336), done.\u001b[K\n", |
| 51 | + "remote: Compressing objects: 100% (285/285), done.\u001b[K\n", |
| 52 | + "remote: Total 336 (delta 88), reused 216 (delta 39), pack-reused 0\u001b[K\n", |
53 | 53 | "Receiving objects: 100% (336/336), 2.11 MiB | 7.48 MiB/s, done.\n", |
54 | 54 | "Resolving deltas: 100% (88/88), done.\n" |
55 | 55 | ] |
|
58 | 58 | "source": [ |
59 | 59 | "!git clone --depth 1 https://github.com/Benchmarking-Initiative/Benchmark-Models-PEtab.git tmp/benchmark-models || (cd tmp/benchmark-models && git pull)\n", |
60 | 60 | "from pathlib import Path\n", |
61 | | - "folder_base = Path('.') / \"tmp\" / \"benchmark-models\" / \"Benchmark-Models\"" |
| 61 | + "\n", |
| 62 | + "folder_base = Path(\".\") / \"tmp\" / \"benchmark-models\" / \"Benchmark-Models\"" |
62 | 63 | ] |
63 | 64 | }, |
64 | 65 | { |
|
77 | 78 | "outputs": [], |
78 | 79 | "source": [ |
79 | 80 | "import petab\n", |
| 81 | + "\n", |
80 | 82 | "model_name = \"Boehm_JProteomeRes2014\"\n", |
81 | 83 | "yaml_file = folder_base / model_name / (model_name + \".yaml\")\n", |
82 | 84 | "petab_problem = petab.Problem.from_yaml(yaml_file)" |
|
570 | 572 | ], |
571 | 573 | "source": [ |
572 | 574 | "from amici.petab_import import import_petab_problem\n", |
| 575 | + "\n", |
573 | 576 | "amici_model = import_petab_problem(petab_problem, force_compile=True)" |
574 | 577 | ] |
575 | 578 | }, |
|
606 | 609 | "source": [ |
607 | 610 | "from amici.petab_objective import simulate_petab\n", |
608 | 611 | "import amici\n", |
| 612 | + "\n", |
609 | 613 | "amici_solver = amici_model.getSolver()\n", |
610 | 614 | "amici_solver.setSensitivityOrder(amici.SensitivityOrder.first)\n", |
611 | 615 | "\n", |
| 616 | + "\n", |
612 | 617 | "def amici_hcb_base(parameters: jnp.array):\n", |
613 | 618 | " return simulate_petab(\n", |
614 | | - " petab_problem, \n", |
615 | | - " amici_model, \n", |
616 | | - " problem_parameters=dict(zip(petab_problem.x_free_ids, parameters)), \n", |
| 619 | + " petab_problem,\n", |
| 620 | + " amici_model,\n", |
| 621 | + " problem_parameters=dict(zip(petab_problem.x_free_ids, parameters)),\n", |
617 | 622 | " scaled_parameters=True,\n", |
618 | 623 | " solver=amici_solver,\n", |
619 | 624 | " )" |
|
635 | 640 | "outputs": [], |
636 | 641 | "source": [ |
637 | 642 | "def amici_hcb_llh(parameters: jnp.array):\n", |
638 | | - " return amici_hcb_base(parameters)['llh']\n", |
| 643 | + " return amici_hcb_base(parameters)[\"llh\"]\n", |
| 644 | + "\n", |
639 | 645 | "\n", |
640 | 646 | "def amici_hcb_sllh(parameters: jnp.array):\n", |
641 | | - " sllh = amici_hcb_base(parameters)['sllh']\n", |
642 | | - " return jnp.asarray(tuple(\n", |
643 | | - " sllh[par_id] for par_id in petab_problem.x_free_ids\n", |
644 | | - " ))" |
| 647 | + " sllh = amici_hcb_base(parameters)[\"sllh\"]\n", |
| 648 | + " return jnp.asarray(\n", |
| 649 | + " tuple(sllh[par_id] for par_id in petab_problem.x_free_ids)\n", |
| 650 | + " )" |
645 | 651 | ] |
646 | 652 | }, |
647 | 653 | { |
|
663 | 669 | "from jax import custom_jvp\n", |
664 | 670 | "\n", |
665 | 671 | "import numpy as np\n", |
| 672 | + "\n", |
| 673 | + "\n", |
666 | 674 | "@custom_jvp\n", |
667 | 675 | "def jax_objective(parameters: jnp.array):\n", |
668 | 676 | " return hcb.call(\n", |
|
695 | 703 | " sllh = hcb.call(\n", |
696 | 704 | " amici_hcb_sllh,\n", |
697 | 705 | " parameters,\n", |
698 | | - " result_shape=jax.ShapeDtypeStruct((petab_problem.parameter_df.estimate.sum(),), np.float64),\n", |
| 706 | + " result_shape=jax.ShapeDtypeStruct(\n", |
| 707 | + " (petab_problem.parameter_df.estimate.sum(),), np.float64\n", |
| 708 | + " ),\n", |
699 | 709 | " )\n", |
700 | 710 | " return llh, sllh.dot(x_dot)" |
701 | 711 | ] |
|
717 | 727 | "source": [ |
718 | 728 | "from jax import value_and_grad\n", |
719 | 729 | "\n", |
720 | | - "parameter_scales = petab_problem.parameter_df.loc[petab_problem.x_free_ids, petab.PARAMETER_SCALE].values\n", |
| 730 | + "parameter_scales = petab_problem.parameter_df.loc[\n", |
| 731 | + " petab_problem.x_free_ids, petab.PARAMETER_SCALE\n", |
| 732 | + "].values\n", |
| 733 | + "\n", |
721 | 734 | "\n", |
722 | 735 | "@jax.jit\n", |
723 | 736 | "@value_and_grad\n", |
724 | 737 | "def jax_objective_with_parameter_transform(parameters: jnp.array):\n", |
725 | | - " par_scaled = jnp.asarray(tuple(\n", |
726 | | - " value if scale == petab.LIN\n", |
727 | | - " else jnp.log(value) if scale == petab.LOG\n", |
728 | | - " else jnp.log10(value)\n", |
729 | | - " for value, scale in zip(parameters, parameter_scales)\n", |
730 | | - " ))\n", |
731 | | - " return jax_objective(par_scaled)\n", |
732 | | - " " |
| 738 | + " par_scaled = jnp.asarray(\n", |
| 739 | + " tuple(\n", |
| 740 | + " value\n", |
| 741 | + " if scale == petab.LIN\n", |
| 742 | + " else jnp.log(value)\n", |
| 743 | + " if scale == petab.LOG\n", |
| 744 | + " else jnp.log10(value)\n", |
| 745 | + " for value, scale in zip(parameters, parameter_scales)\n", |
| 746 | + " )\n", |
| 747 | + " )\n", |
| 748 | + " return jax_objective(par_scaled)" |
733 | 749 | ] |
734 | 750 | }, |
735 | 751 | { |
|
755 | 771 | "metadata": {}, |
756 | 772 | "outputs": [], |
757 | 773 | "source": [ |
758 | | - "llh_jax, sllh_jax = jax_objective_with_parameter_transform(petab_problem.x_nominal_free)" |
| 774 | + "llh_jax, sllh_jax = jax_objective_with_parameter_transform(\n", |
| 775 | + " petab_problem.x_nominal_free\n", |
| 776 | + ")" |
759 | 777 | ] |
760 | 778 | }, |
761 | 779 | { |
|
777 | 795 | "# TODO remove me as soon as sllh in simulate_petab is fixed\n", |
778 | 796 | "sllh = {\n", |
779 | 797 | " name: value / (np.log(10) * par_value)\n", |
780 | | - " for (name, value), par_value in zip(r['sllh'].items(), petab_problem.x_nominal_free)\n", |
| 798 | + " for (name, value), par_value in zip(\n", |
| 799 | + " r[\"sllh\"].items(), petab_problem.x_nominal_free\n", |
| 800 | + " )\n", |
781 | 801 | "}" |
782 | 802 | ] |
783 | 803 | }, |
|
802 | 822 | ], |
803 | 823 | "source": [ |
804 | 824 | "import pandas as pd\n", |
805 | | - "pd.Series(dict(amici=r['llh'], jax=float(llh_jax)))" |
| 825 | + "\n", |
| 826 | + "pd.Series(dict(amici=r[\"llh\"], jax=float(llh_jax)))" |
806 | 827 | ] |
807 | 828 | }, |
808 | 829 | { |
|
905 | 926 | } |
906 | 927 | ], |
907 | 928 | "source": [ |
908 | | - "pd.DataFrame(index=sllh.keys(), data=dict(amici=sllh.values(), jax=np.asarray(sllh_jax)))" |
| 929 | + "pd.DataFrame(\n", |
| 930 | + " index=sllh.keys(), data=dict(amici=sllh.values(), jax=np.asarray(sllh_jax))\n", |
| 931 | + ")" |
909 | 932 | ] |
910 | 933 | }, |
911 | 934 | { |
|
925 | 948 | "outputs": [], |
926 | 949 | "source": [ |
927 | 950 | "jax.config.update(\"jax_enable_x64\", True)\n", |
928 | | - "llh_jax, sllh_jax = jax_objective_with_parameter_transform(petab_problem.x_nominal_free)" |
| 951 | + "llh_jax, sllh_jax = jax_objective_with_parameter_transform(\n", |
| 952 | + " petab_problem.x_nominal_free\n", |
| 953 | + ")" |
929 | 954 | ] |
930 | 955 | }, |
931 | 956 | { |
|
956 | 981 | } |
957 | 982 | ], |
958 | 983 | "source": [ |
959 | | - "pd.Series(dict(amici=r['llh'], jax=float(llh_jax)))" |
| 984 | + "pd.Series(dict(amici=r[\"llh\"], jax=float(llh_jax)))" |
960 | 985 | ] |
961 | 986 | }, |
962 | 987 | { |
|
1059 | 1084 | } |
1060 | 1085 | ], |
1061 | 1086 | "source": [ |
1062 | | - "pd.DataFrame(index=sllh.keys(), data=dict(amici=sllh.values(), jax=np.asarray(sllh_jax)))" |
| 1087 | + "pd.DataFrame(\n", |
| 1088 | + " index=sllh.keys(), data=dict(amici=sllh.values(), jax=np.asarray(sllh_jax))\n", |
| 1089 | + ")" |
1063 | 1090 | ] |
1064 | 1091 | } |
1065 | 1092 | ], |
|
0 commit comments